How to use torch.distributions.multivariate_normal.MultivariateNormal in multi-gpu mode

In single-GPU mode, MultivariateNormal runs correctly, but when I switch to multi-GPU mode, I always get this error:

G = torch.exp(m.log_prob(Delta))
File "xxxxx", line 210, in log_prob
M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
File "xxxxx", line 57, in _batch_mahalanobis
M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2) # shape = b x c
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasStrsmBatched( handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount)`

On the GPU, the parameters of MultivariateNormal must be placed on an explicit device, so I set the device inside the forward function. The code is:

# Delta is a random tensor
mean = torch.zeros(2).to(x.device)
cov = torch.eye(2).to(x.device)
m = MultivariateNormal(mean, cov * self.sigma**2)
G = torch.exp(m.log_prob(Delta))

I would be very grateful for any suggestions.

Hi Skyler,

I encountered the same problem. Did you figure out the solution?

Thanks,
Julie

I ran into this problem recently with CUDA version 11.5 and driver version 495.29.05. It seems that MultivariateNormal on the GPU cannot handle the covariance matrix's shape correctly when loc has a batch dimension and the covariance matrix does not. For example, the code below doesn't work:

import torch
from torch.distributions import MultivariateNormal

num = 524288
loc = torch.rand(num, 12, device="cuda:0")
cov = torch.diag(torch.ones(12, device="cuda:0"))
actions = torch.rand(num, 12, device="cuda:0")
dist = MultivariateNormal(loc, cov)

dist.log_prob(actions)

The code above reports the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "xxxxxx/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py", line 210, in log_prob
    M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
  File "xxxxxx/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py", line 57, in _batch_mahalanobis
    M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2)  # shape = b x c
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasStrsmBatched( handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount)`

The problem was solved when I specified the shape of the cov tensor manually so that loc and cov have the same batch size:

num = 524288
loc = torch.rand(num, 12, device="cuda:0")
cov = torch.diag(torch.ones(12, device="cuda:0")).unsqueeze(0).repeat(num, 1, 1)
actions = torch.rand(num, 12, device="cuda:0")
dist = MultivariateNormal(loc, cov)

dist.log_prob(actions)
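
The same workaround can be wrapped in a small helper. This is only a sketch: `batched_mvn` is a name made up for illustration, the batch size is reduced so it runs quickly, and it falls back to the CPU when no GPU is available. It has not been checked against other CUDA or driver versions.

```python
import torch
from torch.distributions import MultivariateNormal


def batched_mvn(loc, cov):
    """Build a MultivariateNormal whose covariance has loc's batch shape.

    Hypothetical helper: generalizes the unsqueeze/repeat workaround.
    """
    d = loc.shape[-1]
    # Materialize one (d, d) covariance per batch element, like repeat().
    cov_batched = cov.expand(*loc.shape[:-1], d, d).contiguous()
    return MultivariateNormal(loc, covariance_matrix=cov_batched)


# Assumption: use the GPU if there is one, otherwise run on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
num = 1024  # kept small here; the failing case used 524288
loc = torch.rand(num, 12, device=device)
cov = torch.eye(12, device=device)
actions = torch.rand(num, 12, device=device)

log_p = batched_mvn(loc, cov).log_prob(actions)
print(log_p.shape)
```

Note that `expand(...).contiguous()` stores a full copy of cov for every batch element, just like `repeat`, so memory use grows with the batch size.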

It might be the same problem as the one that you encountered.
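
Since the covariance in both posts is diagonal, another option might be to skip MultivariateNormal altogether and use independent univariate normals. Normal.log_prob is elementwise, so it does not go through the batched triangular solve that appears in the traceback. The snippet below is a minimal sketch with made-up sizes. It only checks on the CPU that the two formulations give the same log-probabilities, and I have not confirmed that it avoids the error on the setups described above.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
loc = torch.rand(1024, 12)
std = torch.ones(12)  # per-dimension standard deviations (assumed values)
actions = torch.rand(1024, 12)

# Diagonal Gaussian: 12 independent Normals treated as one 12-dim event.
diag = Independent(Normal(loc, std), 1)
log_p_diag = diag.log_prob(actions)

# Reference on CPU: the equivalent full-covariance distribution.
mvn = MultivariateNormal(loc, covariance_matrix=torch.diag(std ** 2))
log_p_mvn = mvn.log_prob(actions)

# Compare the two results within floating-point tolerance.
print(torch.allclose(log_p_diag, log_p_mvn, atol=1e-4))
```

For the first post, where the covariance is the identity scaled by self.sigma**2, the corresponding call would be Independent(Normal(mean, self.sigma), 1). Note that Normal takes a standard deviation, not a variance.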