WHCSRL Tech Network

Fixing the multivariate_normal error: output parameter (typecode 'd') according to the casting rule ''same_kind''

Project scenario:

Sketch-RNN (PyTorch implementation)


Problem description:

Full error output:

      File "/root/DiffusionModel/Pytorch-Sketch-RNN-master/sketch_rnn.py", line 416, in sample_bivariate_normal
        x = np.random.multivariate_normal(mean, cov, 1)
      File "mtrand.pyx", line 4114, in numpy.random.mtrand.RandomState.multivariate_normal
    TypeError: ufunc 'add' output (typecode 'O') could not be coerced to provided output parameter (typecode 'd') according to the casting rule ''same_kind''

Code that raises the error:

    def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False):
        if greedy:
            return mu_x, mu_y
        mean = [mu_x, mu_y]
        sigma_x *= np.sqrt(hp.temperature)
        sigma_y *= np.sqrt(hp.temperature)
        cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y],
               [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]
        x = np.random.multivariate_normal(mean, cov, 1)
        return x[0][0], x[0][1]
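For reference, the covariance matrix built in this function is the usual bivariate-normal one. Because both standard deviations are multiplied by $\sqrt{\tau}$ (here $\tau$ stands for hp.temperature), the whole matrix ends up scaled by $\tau$:

$$
\Sigma =
\begin{pmatrix}
(\sqrt{\tau}\,\sigma_x)^2 & \rho_{xy}\,(\sqrt{\tau}\,\sigma_x)(\sqrt{\tau}\,\sigma_y) \\
\rho_{xy}\,(\sqrt{\tau}\,\sigma_x)(\sqrt{\tau}\,\sigma_y) & (\sqrt{\tau}\,\sigma_y)^2
\end{pmatrix}
= \tau
\begin{pmatrix}
\sigma_x^2 & \rho_{xy}\,\sigma_x\sigma_y \\
\rho_{xy}\,\sigma_x\sigma_y & \sigma_y^2
\end{pmatrix}
$$

So a lower temperature shrinks the spread of the sampled pen offset around $(\mu_x, \mu_y)$ without changing the correlation $\rho_{xy}$.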

The problem is in this line:

x = np.random.multivariate_normal(mean, cov, 1)

Cause analysis:

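np.random.multivariate_normal performs an np.add operation internally, so most of the material I found suggests passing casting='unsafe' to np.add().

However, multivariate_normal itself has no casting parameter, so that advice cannot be applied here.

The more useful clue is the typecode 'O' in the message. Here mu_x and mu_y are PyTorch tensors rather than Python floats, so mean = [mu_x, mu_y] is a list of tensor objects, and NumPy evidently stores it as an object array (typecode 'O'). multivariate_normal then adds mean into its float64 sample array (typecode 'd', the "provided output parameter"), and the 'same_kind' casting rule does not allow an object result to be written into a float64 array, hence the TypeError.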
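The failing step can be reproduced with NumPy alone. This is a minimal sketch under one assumption: that inside multivariate_normal the mean is added in place to a float64 array, as the error message suggests. An object-dtype array of floats stands in for the list of tensors; no PyTorch is needed. It shows the default same_kind rule rejecting the add, casting='unsafe' accepting it (an option multivariate_normal does not expose), and plain floats avoiding the problem entirely. The wording of the error message differs between NumPy versions.

```python
import numpy as np

# Float64 buffer (typecode 'd'), standing in for the sample array
# that the mean is added to
x = np.zeros((1, 2))

# Object array (typecode 'O'), standing in for a mean built from tensors
mean_obj = np.array([0.5, 1.5], dtype=object)

# 1) In-place add under the default casting='same_kind' is rejected:
#    float64 + object gives an object result, which may not be written
#    into a float64 output. NumPy raises a TypeError subclass.
try:
    np.add(x, mean_obj, out=x)
    failed = False
except TypeError as err:
    failed = True
    print("TypeError:", err)

# 2) casting='unsafe' makes np.add accept it, but multivariate_normal
#    offers no way to pass this argument through
y = np.zeros((1, 2))
np.add(y, mean_obj, out=y, casting="unsafe")

# 3) The actual fix: hand NumPy plain floats from the start
mean_float = np.array([float(v) for v in mean_obj])
z = np.zeros((1, 2))
z += mean_float

print(mean_obj.dtype.char, mean_float.dtype.char)  # O d
```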


Solution:

Pull the values out of the tensors with .item() before doing any arithmetic on them:

    def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False):
        # Convert the single-value tensors to plain Python floats, so that
        # mean and cov are built from numbers instead of tensor objects
        mu_x = mu_x.item()
        mu_y = mu_y.item()
        sigma_x = sigma_x.item()
        sigma_y = sigma_y.item()
        if greedy:
            return mu_x, mu_y
        mean = [mu_x, mu_y]
        # Temperature scaling of the standard deviations
        sigma_x *= np.sqrt(hp.temperature)
        sigma_y *= np.sqrt(hp.temperature)
        cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y],
               [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]
        x = np.random.multivariate_normal(mean, cov, 1)
        return x[0][0], x[0][1]
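A standalone sketch of the corrected function might look like the following. It is not the repository's code: hp.temperature is replaced by a hypothetical temperature argument so the snippet runs without the project's hyperparameter object, and float() is used instead of .item() because it accepts Python floats and NumPy scalars as well as one-element PyTorch tensors (a plain Python float has no .item() method). rho_xy is converted too, so every entry of mean and cov is a plain float.

```python
import numpy as np


def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy,
                            temperature=1.0, greedy=False):
    # float() works on Python floats, NumPy scalars / 0-d arrays and
    # one-element PyTorch tensors, so no tensor objects reach NumPy
    mu_x, mu_y = float(mu_x), float(mu_y)
    sigma_x, sigma_y = float(sigma_x), float(sigma_y)
    rho_xy = float(rho_xy)

    if greedy:
        # Greedy decoding: return the mean without sampling
        return mu_x, mu_y

    # Hypothetical `temperature` argument replaces hp.temperature
    sigma_x *= np.sqrt(temperature)
    sigma_y *= np.sqrt(temperature)

    mean = [mu_x, mu_y]
    cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y],
           [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]

    # One draw of shape (1, 2); all inputs are float64 now
    x = np.random.multivariate_normal(mean, cov, 1)
    return x[0][0], x[0][1]
```

In the original project the call would pass hp.temperature explicitly, e.g. sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, temperature=hp.temperature).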
