RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

I am training my model on Google Colab with batch_size = 128; after 1 epoch it hits this error. I don’t know how to fix it while keeping the same batch size (reducing batch_size to 32 avoids the problem). Here are the Colab specs: Driver Version: 460.32.03, CUDA Version: 11.2.
You can find my notebook here.
Thanks for your help.

3 Likes

It seems that one of your operands is too large to fit in int32 (or negative, but that seems unlikely).

I thought that recent PyTorch versions will give a better error (but won’t work around it):

import torch

# Larger than the int32 limit (2**31 - 1) that cuBLAS gemm dimensions must fit in.
LARGE = 2**31 + 1

# Make each gemm dimension (k, then m, then n) exceed the limit in turn.
for i, j, k in [(1, 1, LARGE), (1, LARGE, 1), (LARGE, 1, 1)]:
    inp = torch.randn(i, k, device="cuda", dtype=torch.half)
    weight = torch.randn(j, k, device="cuda", dtype=torch.half)
    try:
        torch.nn.functional.linear(inp, weight)
    except RuntimeError as e:
        print(e)
    del inp
    del weight

This prints:
at::cuda::blas::gemm<float> argument k must be non-negative and less than 2147483647 but got 2147483649
at::cuda::blas::gemm<float> argument m must be non-negative and less than 2147483647 but got 2147483649
at::cuda::blas::gemm<float> argument n must be non-negative and less than 2147483647 but got 2147483649

But, as said, it doesn’t work around the problem. (And it needs a lot of GPU memory to trigger the bug…)

Maybe you can get a credible backtrace and record the input shapes to the operation that fails.
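One common way to get a credible backtrace is to make kernel launches synchronous (a sketch; this slows things down, so only use it while debugging):

import os

# CUDA errors are normally reported asynchronously, so the Python stack trace
# can point at an unrelated later call. Setting this before CUDA is initialized
# makes launches synchronous, so the failing op raises at its own call site and
# you can log the input shapes there.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
# ... run the failing training step and note the shapes of the tensors involved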

Best regards

Thomas

So what can I do to solve this problem? All I know is to make the batch size smaller.

In order of difficulty:

  • make the batch size smaller,
  • make a minimal reproducing example (i.e. just two or three inputs from torch.randn and the call to torch.nn.functional.linear) and file a bug,
  • hot-patch torch.nn.functional.linear with a workaround (splitting the operation into multiple linear or matmul calls; see the sketch after this list),
  • submit a PR with a fix in PyTorch and discuss whether you can add a test or whether it’d take a prohibitively large amount of GPU memory to run (or hire someone to do so).
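A minimal sketch of the hot-patch idea, with the caveats that it only helps when the oversized dimension is the leading (batch) one and that chunk_size is an arbitrary choice of mine:

import torch

_orig_linear = torch.nn.functional.linear

def chunked_linear(input, weight, bias=None, chunk_size=2**16):
    # Apply the original linear in slices along the batch dimension so each
    # underlying cuBLAS gemm stays well below the int32 size limit.
    if input.shape[0] <= chunk_size:
        return _orig_linear(input, weight, bias)
    parts = [_orig_linear(p, weight, bias) for p in input.split(chunk_size, dim=0)]
    return torch.cat(parts, dim=0)

# The hot-patch: nn.Linear looks up F.linear at call time, so this takes
# effect for already-constructed modules too.
torch.nn.functional.linear = chunked_linear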

Best regards

Thomas

3 Likes

Thanks for your help.

For people getting this error and ending up on this post: it can also be caused by a mismatch between the dimensions of your input tensor and the dimensions of your nn.Linear module (e.g. x.shape = (a, b) and nn.Linear(c, c, bias=False) with b not matching c).

It is a bit sad that PyTorch doesn’t give a more explicit error message.
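For illustration, a minimal sketch of such a mismatch (the shapes are made up; recent PyTorch versions report a clearer shape-mismatch error here instead of the cuBLAS one):

import torch
import torch.nn as nn

x = torch.randn(8, 16, device="cuda")             # input with b = 16 features
layer = nn.Linear(32, 32, bias=False).to("cuda")  # expects c = 32 in_features

try:
    layer(x)  # b != c; on some setups this surfaced as CUBLAS_STATUS_INVALID_VALUE
except RuntimeError as e:
    print(e)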

42 Likes

@Jeremy_Cochoy This was really helpful. Solved my issue.

2 Likes

@Jeremy_Cochoy Thanks for your comments!

@Jeremy_Cochoy Thanks!

Hello @Jeremy_Cochoy
I have added an nn.Linear(512,10) layer to my model and the shape of the input that goes into this layer is torch.Size([32,512,1,1]). I have tried reducing the batch size from 128 to 64 and now to 32, but each of these gives me the same error.
Any idea what could be going wrong?

1 Like

I think you want to transpose the dimensions of your input tensor before and after (the Linear — PyTorch 1.9.0 documentation says it expects an N x * x C_in tensor, and you are giving it a 32 x … x 1 tensor).

Something like linear(x.transpose(1, 3)).transpose(1, 3)?
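A minimal sketch of that, assuming the shapes from your post:

import torch
import torch.nn as nn

linear = nn.Linear(512, 10).to("cuda")
x = torch.randn(32, 512, 1, 1, device="cuda")  # the 512 channels sit in dim 1

# Linear acts on the last dimension, so move the channels there, apply the
# layer, then move the result back.
out = linear(x.transpose(1, 3)).transpose(1, 3)
print(out.shape)  # torch.Size([32, 10, 1, 1])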

Thanks a lot, this also solved my issue!

I got the same error because of a mismatch of the input dimensions in the first layer.

Thanks for the hint!

helped me! thank you!

Hello all,

I had this problem while using a smaller batch size (= 4) to test some code changes; my initial batch size was 64. I checked the shapes for nn.Linear and they matched.

After an hour I found that the only change was the batch size. Increasing the batch size back to 64 made everything work perfectly. PyTorch version 1.8.1. Not sure what causes this error.

Hope it helps!

1 Like

Thanks a lot! Solved my issue.

Hello all,

I am using PyTorch ‘1.13.0+cu117’; my env is NVIDIA-SMI 450.80.02, Driver Version: 450.80.02, CUDA Version: 11.0.

In the Python terminal, I tried this very simple example:

>>> import torch
>>> x=torch.ones(2,2,1).to('cuda')
>>> y=torch.ones(2,1,2).to('cuda')
>>> x
tensor([[[1.],
         [1.]],

        [[1.],
         [1.]]], device='cuda:0')
>>> y
tensor([[[1., 1.]],

        [[1., 1.]]], device='cuda:0')
>>> y@x
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`
>>> torch.bmm(y,x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`
>>> torch.matmul(y,x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`
>>> x=torch.ones(2,1).to('cuda')
>>> y=torch.ones(1,2).to('cuda')
>>> y@x
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
>>> torch.mm(y,x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
>>> torch.mm(x,y)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
>>>

The issue is obviously not caused by mismatched sizes. Does anyone have any idea? Thanks!

Could you post the output of python -m torch.utils.collect_env, please, as I cannot reproduce the error in 1.13.0+cu117 on a 3090.

Hello! I had the same issue as wtliao.

This is my output of python -m torch.utils.collect_env:

PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.17

Python version: 3.8.11 (default, Sep 1 2021, 12:33:46) [GCC 9.3.1 20200408 (Red Hat 9.3.1-2)] (64-bit runtime)
Python platform: Linux-3.10.0-1160.42.2.el7.x86_64-x86_64-with-glibc2.2.5
Is CUDA available: True
CUDA runtime version: 11.4.120
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 470.57.02
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.2.4
/usr/lib64/libcudnn_adv_infer.so.8.2.4
/usr/lib64/libcudnn_adv_train.so.8.2.4
/usr/lib64/libcudnn_cnn_infer.so.8.2.4
/usr/lib64/libcudnn_cnn_train.so.8.2.4
/usr/lib64/libcudnn_ops_infer.so.8.2.4
/usr/lib64/libcudnn_ops_train.so.8.2.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.13.0
[pip3] torchaudio==0.13.0
[pip3] torchcam==0.3.2
[pip3] torchvision==0.14.0
[conda] Could not collect

Also, this error occurs while running:

import torch.nn.functional as F
import torch
a = torch.rand((1, 2, 3)).to('cuda')
b = torch.rand((1, 3, 24, 94)).to('cuda')
grid = F.affine_grid(a, b.size())

File ~/.venv/default/lib64/python3.8/site-packages/torch/nn/functional.py:4332, in affine_grid(theta, size, align_corners)
   4329 elif min(size) <= 0:
   4330     raise ValueError("Expected non-zero, positive output size. Got {}".format(size))
→  4332 return torch.affine_grid_generator(theta, size, align_corners)

RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`

I cannot reproduce the issue on a T4 with torch==1.13.0+cu117:

import torch
torch.cuda.get_device_name(0)
# 'Tesla T4'
torch.__version__
# '1.13.0+cu117'

import torch.nn.functional as F
import torch
a = torch.rand((1, 2, 3)).to('cuda')
b = torch.rand((1, 3, 24, 94)).to('cuda')
grid = F.affine_grid(a, b.size())

print(grid)
tensor([[[[0.4507, 0.2959],
          [0.4582, 0.3064],
          [0.4656, 0.3169],
          ...,
          [1.1288, 1.2493],
          [1.1363, 1.2597],
          [1.1437, 1.2702]],

         [[0.4635, 0.3081],
          [0.4710, 0.3186],
          [0.4784, 0.3291],
          ...,
          [1.1417, 1.2615],
          [1.1491, 1.2719],
          [1.1566, 1.2824]],

         [[0.4763, 0.3203],
          [0.4838, 0.3308],
          [0.4913, 0.3413],
          ...,
          [1.1545, 1.2736],
          [1.1619, 1.2841],
          [1.1694, 1.2946]],
...