I am using torch.distributions.MultivariateNormal with large batches because I am working with long sequential data. For large batches, computing log_prob raises a CUDA runtime error. Here is a snippet that reproduces the issue.
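import torch
# One 3x3 covariance shared by every component
cov = torch.eye(3).to('cuda:1')
# 540000 means -> batch_shape (540000,), event_shape (3,)
means = torch.randn(540000,3).to('cuda:1')
distribs = torch.distributions.MultivariateNormal(means, cov)
obs = torch.randn(540000,3).to('cuda:1')
# Raises the CUBLAS error shown below
log_prob = distribs.log_prob(obs)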
This is the error message:
Traceback (most recent call last):
  File "/data2/users/cb221/ar-hmm/ar-hmm/code_snippet.py", line 11, in
    log_prob = distribs.log_prob(obs)
  File "/data2/packages/anaconda3/envs/py38_pytorch/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py", line 216, in log_prob
    M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
  File "/data2/packages/anaconda3/envs/py38_pytorch/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py", line 59, in _batch_mahalanobis
M_swap = torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) # shape = b x c
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasStrsmBatched( handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount)
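The traceback shows that the failure happens inside torch.linalg.solve_triangular, which dispatches to cublasStrsmBatched. When every component shares one small covariance, as in the snippet, one possible workaround is to compute the log-density by hand with an ordinary matrix multiply instead of a triangular solve over the whole batch. This is a minimal sketch, not a fix for the underlying cuBLAS limit: the function name mvn_log_prob_shared_cov is hypothetical, it assumes obs and means have shape (N, d) and cov has shape (d, d), and it has not been tested on the exact GPU setup above. It inverts the Cholesky factor explicitly, which is reasonable for a small, well-conditioned covariance such as torch.eye(3) but is generally less numerically stable than a triangular solve.

```python
import math

import torch


def mvn_log_prob_shared_cov(obs: torch.Tensor, means: torch.Tensor, cov: torch.Tensor) -> torch.Tensor:
    """Log-density of N(means[i], cov) evaluated at obs[i], for one shared covariance.

    Hypothetical helper. obs, means: (N, d); cov: (d, d). Returns shape (N,).
    """
    d = cov.size(-1)
    L = torch.linalg.cholesky(cov)          # cov = L @ L.T, L lower-triangular
    L_inv = torch.linalg.inv(L)             # (d, d): only a small matrix is inverted
    diff = obs - means                      # (N, d)
    z = diff @ L_inv.T                      # row i equals L^{-1} (obs[i] - means[i])
    mahalanobis = z.pow(2).sum(-1)          # (N,) squared Mahalanobis distances
    half_log_det = torch.log(torch.diagonal(L)).sum()  # 0.5 * log det(cov)
    # Same formula as MultivariateNormal.log_prob: -0.5 * (d log(2 pi) + M) - 0.5 log det(cov)
    return -0.5 * (d * math.log(2 * math.pi) + mahalanobis) - half_log_det
```

On the GPU this would be called as mvn_log_prob_shared_cov(obs, means, cov) with the tensors from the snippet; the large dimension only goes through a matrix multiply and an elementwise reduction.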
For context, this is the output from python3 -m torch.utils.collect_env:
Collecting environment information…
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.8.16 | packaged by conda-forge | (default, Feb 1 2023, 16:01:55) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 2080 Ti
GPU 2: NVIDIA GeForce RTX 2080 Ti
GPU 3: NVIDIA GeForce RTX 2080 Ti
Nvidia driver version: 515.105.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] torch==1.13.1
[conda] numpy 1.24.2 pypi_0 pypi
[conda] torch 1.13.1 pypi_0 pypi
Note: I do not see this error with smaller batch sizes. My current workaround is to split the batch into smaller chunks, compute log_prob on each chunk, and concatenate the results (which is noticeably slower).
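The chunking workaround described above can be written as a small helper. This is a sketch under stated assumptions: the name chunked_mvn_log_prob and the default chunk_size of 100_000 are hypothetical (the batch size at which the error starts is not known, so the value needs tuning), and it assumes obs and means have shape (N, d) with one shared (d, d) covariance. Torch distributions cannot be sliced directly, so a new MultivariateNormal is built per chunk.

```python
import torch


def chunked_mvn_log_prob(obs: torch.Tensor, means: torch.Tensor, cov: torch.Tensor,
                         chunk_size: int = 100_000) -> torch.Tensor:
    """MultivariateNormal(means, cov).log_prob(obs), computed chunk by chunk along dim 0.

    Hypothetical helper; chunk_size is an untuned guess.
    """
    # Factor the shared covariance once instead of once per chunk.
    scale_tril = torch.linalg.cholesky(cov)
    pieces = []
    # torch.split yields matching chunks of means and obs along the batch dimension.
    for m_chunk, x_chunk in zip(torch.split(means, chunk_size), torch.split(obs, chunk_size)):
        dist = torch.distributions.MultivariateNormal(m_chunk, scale_tril=scale_tril)
        pieces.append(dist.log_prob(x_chunk))
    return torch.cat(pieces)
```

Passing scale_tril skips the repeated Cholesky factorization; for a 3x3 covariance that saving is negligible, so most of the remaining overhead comes from the Python loop and the per-chunk kernel launches.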