RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)` while running fine on the CPU

I have a model for classification that works on CPU, but when I try to run the model on GPU using DataParallel I get the following error:

Traceback (most recent call last):
  File "Main.py", line 58, in <module>
    trainer.train(dl_train=train_loader, dl_validation=validation_loader)
  File "/home/dsi/davidsr/AttentionProj/Trainers.py", line 56, in train
    loss.backward()
  File "/home/dsi/davidsr/.local/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/dsi/davidsr/.local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`

I checked that the dimensions of the prediction and the ground truth are the same but I can’t seem to find why it raises this error.

I ran my code with CUDA_LAUNCH_BLOCKING=1 and got:

davidsr@dgx02:~/AttentionProj$ CUDA_LAUNCH_BLOCKING=1, CUDA_VISIBLE_DEVICES=3 python3 Main.py --n_epochs 2 --lr 0.0001 --new_split 0 --mode train --par 0
Traceback (most recent call last):
  File "Main.py", line 58, in <module>
    trainer.train(dl_train=train_loader, dl_validation=validation_loader)
  File "/home/dsi/davidsr/AttentionProj/Trainers.py", line 55, in train
    loss = self.w_bce(y_hat, t)
  File "/home/dsi/davidsr/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dsi/davidsr/AttentionProj/Losses.py", line 27, in forward
    pos = torch.matmul(torch.matmul(t, self.pos_w), torch.log(y_hat + self.eps))
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy)`

It appears my error is in the loss function so I added the implementation of it:

class WeightedBCE(nn.Module):
    def __init__(self, pos_w, neg_w):
        super(WeightedBCE, self).__init__()
        self.pos_w = torch.tensor(pos_w, dtype=torch.float, requires_grad=False)
        self.neg_w = torch.tensor(neg_w, dtype=torch.float, requires_grad=False)
        self.eps = 1e-10
        return

    def forward(self, y_hat, t):
        pos = torch.matmul(torch.matmul(t, self.pos_w), torch.log(y_hat + self.eps))
        neg = torch.matmul(torch.matmul(1 - t, self.pos_w), torch.log(1 - y_hat + self.eps))
        return torch.mean(pos + neg)

Hi,
Thanks for the report.
Could you give some details about:

  • Which GPU, CUDA version and PyTorch version are you using?
  • What are the y_hat and t Tensors you give as input? In particular, could you share size(), stride() and storage_offset() for both?
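
The layout details requested above can be printed right before the failing call. This is a minimal sketch, assuming the tensors are named `y_hat` and `t` as in the loss function; the helper `describe` and the stand-in tensors are illustrative only and not part of the original code:

```python
import torch

def describe(name, x):
    """Return a one-line summary of a tensor's layout and placement."""
    return (f"{name}: size={tuple(x.size())} stride={x.stride()} "
            f"offset={x.storage_offset()} dtype={x.dtype} device={x.device}")

# Stand-in tensors with the shapes from this thread; use the real ones inside forward().
y_hat = torch.rand(12, 14)
t = torch.zeros(12, 14)

print("PyTorch:", torch.__version__, "| built with CUDA:", torch.version.cuda)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
print(describe("y_hat", y_hat))
print(describe("t", t))
```

Printing the device alongside the layout also shows at a glance whether every operand of the failing call lives on the same device.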

cc @ptrblck do you know what could be causing this?

It could be an OOM issue triggering the cublas error, but it might also be an internal cublas issue, so the setup information would be helpful to reproduce this issue on our side.

pytorch version 1.7.0
GPU - Tesla V100
NVIDIA-SMI 440.100 Driver Version: 440.100 CUDA Version: 10.2

Regarding the values you wanted for y_hat and t, the prints were made from within the forward of the loss function:


y_hat
torch.Size([12, 14])
(14, 1)
0

t
torch.Size([12, 14])
(14, 1)
0

Anything else you need?

Thanks! Could you also share the size of pos_w and neg_w? I thought they were scalars, but the code does not run when they are set as scalars.

Yes, they have the size of the expected output (t or y_hat). Each one starts as a numpy array of shape (1,14) and is then converted to a torch.tensor.

torch.Size([1,14])

And they are constants

Hi,

I think there is still something I’m missing about sizes. Here is my attempt to repro, but the shapes don’t match. Could you give me an updated version with all the right shapes please? Thanks!

import torch
from torch import nn

class WeightedBCE(nn.Module):
    def __init__(self, pos_w, neg_w):
        super(WeightedBCE, self).__init__()
        pos_w = torch.tensor(pos_w, dtype=torch.float, requires_grad=False)
        neg_w = torch.tensor(neg_w, dtype=torch.float, requires_grad=False)
        self.register_buffer("pos_w", pos_w)
        self.register_buffer("neg_w", neg_w)
        self.eps = 1e-10
        return

    def forward(self, y_hat, t):
        pos = torch.matmul(torch.matmul(t, self.pos_w), torch.log(y_hat + self.eps))
        neg = torch.matmul(torch.matmul(1 - t, self.pos_w), torch.log(1 - y_hat + self.eps))
        return torch.mean(pos + neg)


y_hat = torch.rand(12, 14, device="cuda", requires_grad=True)
t = torch.rand(12, 14, device="cuda", requires_grad=True)

pos_w = torch.rand(1, 14).numpy()
neg_w = torch.rand(1, 14).numpy()
mod = WeightedBCE(pos_w, neg_w).cuda()

mod(y_hat, t)

You got the sizes right. This is what's so weird.

I don’t really get how it can run on my CPU while on the GPU it throws the error.

The sizes are unfortunately not right, as the code still raises:

RuntimeError: mat1 dim 1 must match mat2 dim 0

so could you please double check the shapes and/or post an executable code snippet, so that we could debug it?
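
The mismatch can be checked on the CPU without any GPU. The sketch below uses stand-in tensors with the shapes reported in this thread and shows how `torch.matmul` treats a 1-D weight of shape `(14,)` versus a 2-D weight of shape `(1, 14)`; the variable names are illustrative:

```python
import torch

t = torch.zeros(12, 14)
y_hat = torch.rand(12, 14)

# A 1-D weight of shape (14,) is treated as a vector: (12, 14) @ (14,) -> (12,)
w_1d = torch.rand(14)
inner = torch.matmul(t, w_1d)
print(inner.shape)

# (12,) @ (12, 14) -> (14,): the 1-D left operand acts as a row vector
outer = torch.matmul(inner, torch.log(y_hat + 1e-10))
print(outer.shape)

# A 2-D weight of shape (1, 14) is a matrix, and (12, 14) @ (1, 14) is not a valid product
w_2d = torch.rand(1, 14)
try:
    torch.matmul(t, w_2d)
    failed = False
except RuntimeError as err:
    failed = True
    print("shape mismatch:", err)
```

So the loss can only run as written if the weights are 1-D with 14 elements, not `(1, 14)`.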

I ran the code on the CPU, and these are the variables with their sizes:

y_hat size is: torch.Size([12, 14])
y_hat values:

tensor([[0.6129, 0.5309, 0.5493, 0.5381, 0.3087, 0.5583, 0.6137, 0.5149, 0.4085,
         0.5862, 0.2384, 0.5680, 0.5260, 0.4991],
        [0.6569, 0.5384, 0.5118, 0.5537, 0.2853, 0.5693, 0.5910, 0.5591, 0.4307,
         0.6214, 0.2627, 0.5559, 0.4875, 0.5189],
        [0.6661, 0.5641, 0.5279, 0.5512, 0.2661, 0.5696, 0.6380, 0.5634, 0.4121,
         0.6507, 0.2291, 0.5484, 0.5269, 0.5167],
        [0.6080, 0.5319, 0.5410, 0.5360, 0.3337, 0.5400, 0.6016, 0.5272, 0.4206,
         0.5621, 0.2836, 0.5702, 0.5462, 0.5181],
        [0.6449, 0.5319, 0.5020, 0.5396, 0.3185, 0.5530, 0.6014, 0.5320, 0.4364,
         0.5910, 0.2769, 0.5569, 0.5150, 0.5478],
        [0.6354, 0.5295, 0.5082, 0.5238, 0.3286, 0.5463, 0.6044, 0.5155, 0.4336,
         0.5611, 0.2828, 0.5773, 0.5416, 0.5526],
        [0.6253, 0.5377, 0.5210, 0.5412, 0.3121, 0.5338, 0.5989, 0.5410, 0.4344,
         0.5828, 0.2790, 0.5386, 0.4860, 0.5109],
        [0.5994, 0.5170, 0.5295, 0.5254, 0.3505, 0.5416, 0.5829, 0.5294, 0.4342,
         0.5438, 0.2987, 0.5563, 0.5430, 0.5301],
        [0.7066, 0.6006, 0.5521, 0.5716, 0.2426, 0.5623, 0.7090, 0.5899, 0.3890,
         0.7019, 0.1959, 0.5618, 0.5408, 0.5389],
        [0.6898, 0.5965, 0.5376, 0.5712, 0.2483, 0.5569, 0.6715, 0.5582, 0.4050,
         0.6706, 0.2231, 0.5759, 0.5098, 0.5386],
        [0.7319, 0.6521, 0.5477, 0.6104, 0.1955, 0.5622, 0.7457, 0.5822, 0.3798,
         0.7330, 0.1674, 0.6016, 0.5055, 0.5282],
        [0.6339, 0.5636, 0.5512, 0.5806, 0.3190, 0.5315, 0.5829, 0.5866, 0.4368,
         0.5918, 0.2504, 0.5495, 0.5461, 0.4653]], grad_fn=<SigmoidBackward>)

t size is: torch.Size([12, 14])
t values:

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

pos_w size is: torch.Size([14])
pos_w values:

tensor([0.7754, 0.9461, 0.9093, 0.9552, 0.7428, 0.9514, 0.9681, 0.9956, 0.6159,
        0.8887, 0.8778, 0.9344, 0.9717, 0.8984])

neg_w size is: torch.Size([14])
neg_w values:

tensor([0.2246, 0.0539, 0.0907, 0.0448, 0.2572, 0.0486, 0.0319, 0.0044, 0.3841,
        0.1113, 0.1222, 0.0656, 0.0283, 0.1016])

I was indeed mistaken and I apologize for that.

Oh right!
So the updated code below runs!
But at least on the P100 on Colab, it does not raise any error :confused:

@ptrblck would you have a V100 handy to test this?

import torch
from torch import nn

class WeightedBCE(nn.Module):
    def __init__(self, pos_w, neg_w):
        super(WeightedBCE, self).__init__()
        pos_w = torch.tensor(pos_w, dtype=torch.float, requires_grad=False)
        neg_w = torch.tensor(neg_w, dtype=torch.float, requires_grad=False)
        self.register_buffer("pos_w", pos_w)
        self.register_buffer("neg_w", neg_w)
        self.eps = 1e-10
        return

    def forward(self, y_hat, t):
        pos = torch.matmul(torch.matmul(t, self.pos_w), torch.log(y_hat + self.eps))
        neg = torch.matmul(torch.matmul(1 - t, self.pos_w), torch.log(1 - y_hat + self.eps))
        return torch.mean(pos + neg)


y_hat = torch.rand(12, 14, device="cuda", requires_grad=True)
t = torch.rand(12, 14, device="cuda", requires_grad=True)

pos_w = torch.rand(14).numpy()
neg_w = torch.rand(14).numpy()
mod = WeightedBCE(pos_w, neg_w).cuda()

mod(y_hat, t)
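
One difference from the loss posted at the top of the thread is that this version registers the weights as buffers, so `.cuda()` or `.to(device)` moves them together with the module, whereas a tensor stored as a plain attribute stays on the CPU. The thread does not establish whether this explains the original error. A small sketch of the difference, using the hypothetical class names `PlainAttr` and `WithBuffer`; it falls back to the CPU when no GPU is present:

```python
import torch
from torch import nn

class PlainAttr(nn.Module):
    def __init__(self, w):
        super().__init__()
        # Plain attribute: the module does not track this tensor.
        self.w = torch.tensor(w, dtype=torch.float)

class WithBuffer(nn.Module):
    def __init__(self, w):
        super().__init__()
        # Registered buffer: tracked by the module and moved along with it.
        self.register_buffer("w", torch.tensor(w, dtype=torch.float))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = [0.5] * 14

plain = PlainAttr(weights).to(device)
buffered = WithBuffer(weights).to(device)

print(plain.w.device)      # stays on the CPU, even when the module is on the GPU
print(buffered.w.device)   # follows the module
print([name for name, _ in buffered.named_buffers()])
```

Separately, the negative term in `forward` multiplies by `self.pos_w`; if `self.neg_w` was intended there, that would be worth checking too.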

I also tested it on a machine with TITAN RTX and it also works.

@albanD Sure, I can test it.

The code runs fine on a machine using a V100 DGXs-16GB (driver 440.33.01) and V100-SXM3-32GB (driver 450.51.06) using the conda PyTorch binaries for 1.7.0 and 1.7.1 with the CUDA runtime 10.2.

Hi, my error message is as follows:

Current run is terminating due to exception: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
Engine run is terminating due to exception: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
Traceback (most recent call last):
  File "train.py", line 474, in <module>
    main(**kwargs)
  File "train.py", line 301, in main
    trainer.run(train_loader, epochs)
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/ignite/engine/engine.py", line 702, in run
    return self._internal_run()
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/ignite/engine/engine.py", line 775, in _internal_run
    self._handle_exception(e)
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/ignite/engine/engine.py", line 469, in _handle_exception
    raise e
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/ignite/engine/engine.py", line 745, in _internal_run
    time_taken = self._run_once_on_dataset()
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/ignite/engine/engine.py", line 850, in _run_once_on_dataset
    self._handle_exception(e)
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/ignite/engine/engine.py", line 469, in _handle_exception
    raise e
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/ignite/engine/engine.py", line 833, in _run_once_on_dataset
    self.state.output = self._process_function(self, self.state.batch)
  File "train.py", line 165, in step
    losses["total_loss"].backward()
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/xxx/miniconda3/envs/xxx/lib/python3.6/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`

Even when I used the debug flag CUDA_LAUNCH_BLOCKING=1, the output message did not change.
The code works well under pytorch=1.3.0, cudatoolkit=10.0 on an RTX 2080Ti.
Does this have something to do with PyTorch-Ignite?
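
For reference, `CUDA_LAUNCH_BLOCKING=1` has to be in the process environment before CUDA is initialized. Besides prefixing the shell command, one way to make sure of that is to launch the script from Python with the variable set. A sketch, where `run_blocking` is a hypothetical helper and the child command is a stand-in that only echoes the variable:

```python
import os
import subprocess
import sys

def run_blocking(cmd):
    """Run cmd with CUDA_LAUNCH_BLOCKING=1 added to a copy of the current environment."""
    env = dict(os.environ, CUDA_LAUNCH_BLOCKING="1")
    return subprocess.run(cmd, env=env, capture_output=True, text=True)

# Stand-in child process; replace with e.g. [sys.executable, "train.py"] for a real run.
result = run_blocking(
    [sys.executable, "-c", "import os; print(os.environ['CUDA_LAUNCH_BLOCKING'])"]
)
print(result.stdout.strip())
```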

Btw, my config is,

PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 3080
Nvidia driver version: 460.73.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] pytorch-ignite==0.4.4
[pip3] torch==1.8.1
[pip3] torchaudio==0.8.0a0+e4e171a
[pip3] torchvision==0.9.1
[conda] blas                      2.109                       mkl    conda-forge
[conda] blas-devel                3.9.0                     9_mkl    conda-forge
[conda] cudatoolkit               11.1.74              h6bb024c_0    nvidia
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] ignite                    0.4.4                      py_0    pytorch
[conda] libblas                   3.9.0                     9_mkl    conda-forge
[conda] libcblas                  3.9.0                     9_mkl    conda-forge
[conda] liblapack                 3.9.0                     9_mkl    conda-forge
[conda] liblapacke                3.9.0                     9_mkl    conda-forge
[conda] mkl                       2021.2.0           h726a3e6_389    conda-forge
[conda] mkl-devel                 2021.2.0           ha770c72_390    conda-forge
[conda] mkl-include               2021.2.0           h726a3e6_389    conda-forge
[conda] numpy                     1.19.5           py36h2aa4a07_1    conda-forge
[conda] pytorch                   1.8.1           py3.6_cuda11.1_cudnn8.0.5_0    pytorch
[conda] torchaudio                0.8.1                      py36    pytorch
[conda] torchvision               0.9.1                py36_cu111    pytorch

Thanks for any help.

This error often points to an out-of-memory issue, since the cuBLAS handle cannot be created. You could try to reduce the memory footprint, e.g. by lowering the batch size.
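
To see how close a run is to the limit, the memory counters PyTorch keeps can be printed just before the failing step. A minimal sketch; the helper name `cuda_memory_report` is made up for illustration, and it returns `None` on machines without CUDA:

```python
import torch

def cuda_memory_report(device=None):
    """Return this process's PyTorch memory usage on a GPU in bytes, or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    if device is None:
        device = torch.cuda.current_device()
    return {
        "total_bytes": torch.cuda.get_device_properties(device).total_memory,
        "allocated_bytes": torch.cuda.memory_allocated(device),  # held by live tensors
        "reserved_bytes": torch.cuda.memory_reserved(device),    # held by the caching allocator
    }

report = cuda_memory_report()
print(report if report is not None else "CUDA not available")
```

These counters only cover allocations made by PyTorch in the current process. On a shared GPU, memory used by other processes does not show up here, so `nvidia-smi` is the better place to check overall usage.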

Hi ptrblck,

Thanks for your really quick reply. Yes, that was exactly the problem!