RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR on GPU

Hello, I’m trying to use a deeplabv3_resnet50 from torchvision.models.segmentation. When I use the complete dataset (1246 training patches and 297 validation patches, each 256x256 with 6 channels) and launch the training function on the GPU, I get the error shown further down (on CPU it works). I’ve tried different batch sizes (from 4 to 32), so it doesn’t seem to be a memory error. @ptrblck can you help me?
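For reference, the setup is roughly the following (a sketch, not my actual training code; the 6-channel input adaptation and the class count below are assumptions):

import torch
import torch.nn as nn
from torchvision.models.segmentation import deeplabv3_resnet50

num_classes = 8  # assumption: replace with the real number of classes
model = deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=num_classes)

# The torchvision backbone expects 3-channel input, so the first conv
# is swapped for one that accepts the 6-channel patches.
model.backbone.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)

x = torch.randn(4, 6, 256, 256)   # a batch of 256x256 patches with 6 channels
out = model(x)['out']             # logits of shape [4, num_classes, 256, 256]

The training output and error: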

12:06:52 - epoch: 1 - batch: 0 - loss: 2.338578224182129
12:07:03 - epoch: 1 - batch: 10 - loss: 1.7042455673217773
12:07:14 - epoch: 1 - batch: 20 - loss: 1.710103988647461
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [768,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [769,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [770,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [771,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [772,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [773,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [512,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [513,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [514,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [515,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [516,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [517,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [256,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [257,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [258,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [259,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [260,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [261,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [4,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [5,0,0] Assertion `t >= 0 && t < n_classes` failed.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_28353/2565883664.py in <cell line: 2>()
      1 optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.01)
----> 2 train_loop(50, train_dataloader, valid_dataloader, model, loss, optimizer, 
      3            acc_fns=[oa], batch_tfms=tfms)

/tmp/ipykernel_28353/2022702472.py in train_loop(epochs, train_dl, val_dl, model, loss_fn, optimizer, acc_fns, batch_tfms)
     25             #print(f"X: {X.shape}")
     26             #print(f"y: {y.shape}")
---> 27             pred = cuda_model(X)['out']
     28             loss = loss_fn(pred, y)
     29 

~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/python3/lib/python3.10/site-packages/torchvision/models/segmentation/_utils.py in forward(self, x)
     21         input_shape = x.shape[-2:]
     22         # contract: features is a dict of tensors
---> 23         features = self.backbone(x)
     24 
     25         result = OrderedDict()

~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/python3/lib/python3.10/site-packages/torchvision/models/_utils.py in forward(self, x)
     67         out = OrderedDict()
     68         for name, module in self.items():
---> 69             x = module(x)
     70             if name in self.return_layers:
     71                 out_name = self.return_layers[name]

~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/nn/modules/conv.py in forward(self, input)
    461 
    462     def forward(self, input: Tensor) -> Tensor:
--> 463         return self._conv_forward(input, self.weight, self.bias)
    464 
    465 class Conv3d(_ConvNd):

~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    457                             weight, bias, self.stride,
    458                             _pair(0), self.dilation, self.groups)
--> 459         return F.conv2d(input, weight, bias, self.stride,
    460                         self.padding, self.dilation, self.groups)
    461 

RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR

I’m on AWS Sagemaker with the following configuration:

import torch
import torch.nn as nn

print(torch.cuda.get_device_name())
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())

Tesla T4
1.13.1+cu116
11.6
8400

CUDA toolkit (nvcc):

!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Wed_Jul_22_19:09:09_PDT_2020
Cuda compilation tools, release 11.0, V11.0.221
Build cuda_11.0_bu.TC445_37.28845127_0

Thanks for your attention.

The cuDNN error seems to be misleading and might be triggered by an earlier failure.
The stacktrace points to an invalid target index in nn.NLLLoss or nn.CrossEntropyLoss:

../aten/src/ATen/native/cuda/NLLLoss2d.cu:104: nll_loss2d_forward_kernel: block: [8,0,0], thread: [768,0,0] Assertion `t >= 0 && t < n_classes` failed.

so make sure the target contains class indices in the range [0, nb_classes-1].
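For example, a quick check along these lines (a sketch; y and nb_classes stand in for the actual target batch and the number of output channels) fails with a readable message before the loss is computed:

import torch

def check_targets(y: torch.Tensor, nb_classes: int) -> None:
    # Fail early on the CPU side with a readable message instead of the
    # device-side assert shown above.
    if y.min() < 0 or y.max() >= nb_classes:
        raise ValueError(
            f"invalid class indices in target: min={y.min().item()}, "
            f"max={y.max().item()}, expected range [0, {nb_classes - 1}]"
        )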

Thanks, I finally found that the problem was an invalid label in the target that had slipped past my validation check. Have a nice day.
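For anyone hitting the same issue, a one-off scan of the dataloader (a sketch; train_dataloader and nb_classes are placeholders for the actual objects) points to the offending batches before training starts:

bad_batches = []
for i, (X, y) in enumerate(train_dataloader):
    # Record every batch whose labels fall outside [0, nb_classes - 1].
    if y.min() < 0 or y.max() >= nb_classes:
        bad_batches.append((i, int(y.min()), int(y.max())))
print(bad_batches)  # an empty list means all targets are valid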