torch.distributions.MultivariateNormal.log_prob throws RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED for large batch sizes

I am using torch.distributions.MultivariateNormal with large batches because I am working with long sequential data. For large batches, computing log_prob raises a CUDA RuntimeError. Here is a snippet that reproduces the issue.

import torch

# One shared 3x3 covariance and 540000 per-row means
cov = torch.eye(3).to('cuda:1')
means = torch.randn(540000, 3).to('cuda:1')
distribs = torch.distributions.MultivariateNormal(means, cov)

obs = torch.randn(540000, 3).to('cuda:1')

# Raises CUBLAS_STATUS_NOT_SUPPORTED on the GPU
log_prob = distribs.log_prob(obs)

This is the error message:

Traceback (most recent call last):
  File "/data2/users/cb221/ar-hmm/ar-hmm/code_snippet.py", line 11, in <module>
    log_prob = distribs.log_prob(obs)
  File "/data2/packages/anaconda3/envs/py38_pytorch/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py", line 216, in log_prob
    M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
  File "/data2/packages/anaconda3/envs/py38_pytorch/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py", line 59, in _batch_mahalanobis
    M_swap = torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) # shape = b x c
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasStrsmBatched( handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount)

For context, this is the output of python3 -m torch.utils.collect_env:

Collecting environment information...
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.16 | packaged by conda-forge | (default, Feb 1 2023, 16:01:55) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 2080 Ti
GPU 2: NVIDIA GeForce RTX 2080 Ti
GPU 3: NVIDIA GeForce RTX 2080 Ti

Nvidia driver version: 515.105.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] torch==1.13.1
[conda] numpy 1.24.2 pypi_0 pypi
[conda] torch 1.13.1 pypi_0 pypi

Note: I do not see this error with smaller batch sizes. My current workaround is to split the batch into smaller chunks, compute log_prob on each chunk in a loop, and concatenate the results, which is noticeably slower.
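A minimal sketch of that chunking workaround, assuming a single covariance shared by all rows as in the snippet above. chunked_log_prob and its chunk_size default are hypothetical choices, not part of torch.distributions; pick a chunk size that stays clear of the failing size on your own setup.

```python
import torch
from torch.distributions import MultivariateNormal


def chunked_log_prob(means, cov, obs, chunk_size=262_144):
    # Hypothetical helper: evaluates MultivariateNormal(means, cov).log_prob(obs)
    # on row chunks so that no single call sees the full batch.
    # chunk_size is an assumed default, not a documented limit.
    pieces = []
    # torch.split returns views along dim 0, so the inputs are not copied
    for m, x in zip(torch.split(means, chunk_size), torch.split(obs, chunk_size)):
        pieces.append(MultivariateNormal(m, cov).log_prob(x))
    return torch.cat(pieces)


# Fall back to CPU so the sketch also runs without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
cov = torch.eye(3, device=device)
means = torch.randn(540000, 3, device=device)
obs = torch.randn(540000, 3, device=device)
log_prob = chunked_log_prob(means, cov, obs)  # shape (540000,)
```

Each chunk builds its own MultivariateNormal, so the Cholesky factor of cov is recomputed per chunk; for a 3x3 matrix that cost is negligible compared with the Python loop itself.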

Thanks for forwarding the issue! I can reproduce it in 2.0.0+cu118 and we’ll look into it.


A similar issue was reported upstream in pytorch/pytorch issue #97211 ("Triangular solve fails on batches of matrices of size > (*, 524281)"). It has been fixed in cuBLAS, and the fix should hopefully land in a newer release soon.
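Until that fix ships, one way to sidestep the batched triangular solve for a single shared covariance is to compute the log-density by hand: log p(x) = -0.5 * (k log 2π + ||L⁻¹(x - μ)||²) - Σ log L_ii, where cov = L Lᵀ. This is a swapped-in technique, not what torch.distributions does internally: it inverts the small Cholesky factor once and whitens all residuals with a plain matmul, so the solve_triangular call seen in the traceback is never made. shared_cov_mvn_log_prob is a hypothetical helper name. An explicit inverse is less robust than a triangular solve for ill-conditioned covariances, so treat this as a sketch for small, well-conditioned matrices like the 3x3 one here.

```python
import math

import torch


def shared_cov_mvn_log_prob(means, cov, obs):
    # Hypothetical helper: log N(obs | means, cov) for one (k, k) covariance
    # shared by every row of means and obs (both shaped (N, k)).
    k = cov.shape[-1]
    L = torch.linalg.cholesky(cov)           # cov = L @ L.T
    L_inv = torch.linalg.inv(L)              # k x k inverse, computed once
    z = (obs - means) @ L_inv.T              # rows are L^{-1}(x - mu), shape (N, k)
    maha = z.pow(2).sum(-1)                  # squared Mahalanobis distance per row
    half_log_det = L.diagonal().log().sum()  # 0.5 * log det(cov)
    return -0.5 * (k * math.log(2 * math.pi) + maha) - half_log_det


device = "cuda" if torch.cuda.is_available() else "cpu"
cov = torch.eye(3, device=device)
means = torch.randn(540000, 3, device=device)
obs = torch.randn(540000, 3, device=device)
log_prob = shared_cov_mvn_log_prob(means, cov, obs)  # shape (540000,)
```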


I got the same error with the latest PyTorch + CUDA, but the code ran fine for me on PyTorch 1.7.1 with CUDA toolkit 11.0.

A similar problem was reported in another thread. The fix there was to repeat the covariance matrix along the batch dimension.
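A sketch of that workaround applied to the original snippet. My reading of _batch_mahalanobis in the traceback (an interpretation, not a confirmed diagnosis) is that with a single 3x3 covariance, all 540000 residuals become columns of one right-hand side passed to solve_triangular. Giving every row its own copy of the covariance makes the Cholesky factor batched as well, which should turn that into 540000 independent 3x3 solves with one column each. Whether this avoids the cuBLAS error on your driver and toolkit is something to verify.

```python
import torch
from torch.distributions import MultivariateNormal

# Fall back to CPU so the sketch also runs without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
n = 540000
cov = torch.eye(3, device=device)
means = torch.randn(n, 3, device=device)
obs = torch.randn(n, 3, device=device)

# One copy of the covariance per row: shape (n, 3, 3) instead of (3, 3).
# The constructor's Cholesky factor is then batched too, so log_prob
# solves n small systems rather than one system with n columns.
cov_batched = cov.repeat(n, 1, 1)
distribs = MultivariateNormal(means, cov_batched)
log_prob = distribs.log_prob(obs)  # shape (n,)
```

The cost is memory for n copies of the 3x3 matrix and of its Cholesky factor, roughly 19 MB each in float32 for n = 540000.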