How to use torch.distributions.multivariate_normal.MultivariateNormal in multi-GPU mode

In single-GPU mode, MultivariateNormal runs correctly, but when I switch to multi-GPU mode, I always get this error:

G = torch.exp(m.log_prob(Delta))
File "xxxxx", line 210, in log_prob
M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
File "xxxxx", line 57, in _batch_mahalanobis
M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2) # shape = b x c
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasStrsmBatched( handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount)`
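The traceback shows where the failure happens: log_prob calls _batch_mahalanobis, which solves a lower-triangular system against the Cholesky factor of the covariance, and that batched solve is the cublasStrsmBatched call in the error. As a rough sketch of the same math (assuming PyTorch 1.11 or later for torch.linalg.solve_triangular, and inputs of shape (N, d)), the hypothetical helper mvn_log_prob_manual below recomputes log_prob by hand. It is not PyTorch's internal code, only a way to check the computation on the CPU or on a single device.

```python
import math
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

def mvn_log_prob_manual(value, loc, cov):
    """Hypothetical helper: MultivariateNormal log-density for value of shape (N, d).

    Follows the same steps as the traceback: Cholesky-factor cov = L @ L.T,
    solve the lower-triangular system L y = (value - loc), and use the squared
    norm of y as the squared Mahalanobis distance.
    """
    d = loc.shape[-1]
    L = torch.linalg.cholesky(cov)                             # lower-triangular factor, shape (d, d)
    diff = (value - loc).T                                     # shape (d, N): one column per sample
    y = torch.linalg.solve_triangular(L, diff, upper=False)    # solves L @ y = diff
    maha = y.pow(2).sum(dim=0)                                 # squared Mahalanobis distance, shape (N,)
    half_log_det = torch.log(torch.diagonal(L)).sum()          # equals 0.5 * log|cov|
    return -0.5 * (d * math.log(2 * math.pi) + maha) - half_log_det

# Usage: compare against PyTorch's own log_prob on the CPU.
torch.manual_seed(0)
loc, cov = torch.zeros(2), torch.eye(2) * 0.5**2
Delta = torch.randn(5, 2)
ref = MultivariateNormal(loc, covariance_matrix=cov).log_prob(Delta)
print(torch.allclose(mvn_log_prob_manual(Delta, loc, cov), ref, atol=1e-5))
```

If this matches on the CPU but the original call still fails on a non-default GPU, that points at the device-specific cuBLAS call rather than at the inputs themselves.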

When MultivariateNormal is used on the GPU, its mean and covariance must be on the same device as the input, so I create them on x.device inside the forward function.
The code is:

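import torch
from torch.distributions.multivariate_normal import MultivariateNormal

# Inside forward(self, x); Delta is a random tensor whose last dimension is 2
mean = torch.zeros(2, device=x.device)
cov = torch.eye(2, device=x.device)
m = MultivariateNormal(mean, covariance_matrix=cov * self.sigma**2)
G = torch.exp(m.log_prob(Delta))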
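Because the covariance in this snippet is isotropic (sigma**2 times the identity), the distribution factorizes into independent 1-D normals, so one possible workaround is to build it with Independent(Normal(...), 1), which gives the same log_prob without any triangular solve. This is a sketch under that assumption: isotropic_gaussian_density is a hypothetical helper, sigma is assumed to be a Python float or a 0-dim tensor on the same device as delta, and the approach sidesteps the failing cuBLAS kernel rather than explaining why that kernel fails under multi-GPU execution.

```python
import torch
from torch.distributions import Independent, Normal
from torch.distributions.multivariate_normal import MultivariateNormal

def isotropic_gaussian_density(delta, sigma):
    """Hypothetical helper: density of N(0, sigma**2 * I) at each row of delta.

    Same values as MultivariateNormal(zeros(d), eye(d) * sigma**2), but built from
    independent 1-D Normals, so log_prob involves no triangular solve.
    """
    d = delta.shape[-1]
    # Create parameters on delta's device and dtype (e.g. the GPU of the current replica).
    loc = torch.zeros(d, device=delta.device, dtype=delta.dtype)
    scale = sigma * torch.ones(d, device=delta.device, dtype=delta.dtype)
    # reinterpreted_batch_ndims=1 sums the per-coordinate log-probs over the last dim.
    m = Independent(Normal(loc, scale), 1)
    return torch.exp(m.log_prob(delta))

# Usage: same values as the MultivariateNormal version in the question.
torch.manual_seed(0)
sigma = 0.7
Delta = torch.randn(4, 2)
G_ref = torch.exp(MultivariateNormal(torch.zeros(2), torch.eye(2) * sigma**2).log_prob(Delta))
print(torch.allclose(isotropic_gaussian_density(Delta, sigma), G_ref, atol=1e-6))
```

Inside forward this would read G = isotropic_gaussian_density(Delta, self.sigma). For d = 2 it is also just the closed form G = exp(-(Delta**2).sum(-1) / (2 * sigma**2)) / (2 * pi * sigma**2), which could be written with plain tensor ops if one prefers to drop torch.distributions altogether.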

I would be very grateful for any suggestions.

Hi Skyler,

I encountered the same problem. Did you figure out the solution?