Weird CUDA error

This is a weird cuda error that I get. If someone could just explain what it means or what is causing it, that would be awesome!

Traceback (most recent call last):
File “main.py”, line 40, in
train_correspondence_block(root_dir)
File “/home/jovyan/work/correspondence_block.py”, line 82, in train_correspondence_block
loss.backward()
File “/opt/venv/lib/python3.7/site-packages/torch/tensor.py”, line 198, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File “/opt/venv/lib/python3.7/site-packages/torch/autograd/init.py”, line 100, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED (createCuDNNHandle at /pytorch/aten/src/ATen/cudnn/Handle.cpp:9)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x46 (0x7f0c87e01536 in /opt/venv/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: + 0x10b29d8 (0x7f0c8930f9d8 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #2: at::native::getCudnnHandle() + 0xe54 (0x7f0c893111b4 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #3: + 0xf2bcfc (0x7f0c89188cfc in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #4: + 0xf2cd91 (0x7f0c89189d91 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #5: + 0xf30dcb (0x7f0c8918ddcb in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #6: at::native::cudnn_convolution_backward_input(c10::ArrayRef, at::Tensor const&, at::Tensor const&, c10::ArrayRef, c10::ArrayRef, c10::ArrayRef, long, bool, bool) + 0xb2 (0x7f0c8918e322 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #7: + 0xf97e40 (0x7f0c891f4e40 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #8: + 0xfdc6d8 (0x7f0c892396d8 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #9: at::native::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef, c10::ArrayRef, c10::ArrayRef, long, bool, bool, std::array<bool, 2ul>) + 0x4fa (0x7f0c8918f9ba in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #10: + 0xf9816b (0x7f0c891f516b in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #11: + 0xfdc734 (0x7f0c89239734 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #12: + 0x2c809b6 (0x7f0cc278f9b6 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #13: + 0x2cd0444 (0x7f0cc27df444 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #14: torch::autograd::generated::CudnnConvolutionBackward::apply(std::vector<at::Tensor, std::allocatorat::Tensor >&&) + 0x378 (0x7f0cc23a7918 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #15: + 0x2d89c05 (0x7f0cc2898c05 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #16: torch::autograd::Engine::evaluate_function(std::shared_ptrtorch::autograd::GraphTask&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x16f3 (0x7f0cc2895f03 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #17: torch::autograd::Engine::thread_main(std::shared_ptrtorch::autograd::GraphTask const&, bool) + 0x3d2 (0x7f0cc2896ce2 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #18: torch::autograd::Engine::thread_init(int) + 0x39 (0x7f0cc288f359 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #19: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x7f0ccefce828 in /opt/venv/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #20: + 0xb9e6f (0x7f0cf1310e6f in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #21: + 0x74a4 (0x7f0cf8b804a4 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #22: clone + 0x3f (0x7f0cf81b7d0f in /lib/x86_64-linux-gnu/libc.so.6)

This is my train_correspondence_block function:

def train_correspondence_block(root_dir,epochs = 20): # Loop to train the correspondence block

# dataset for correspondence block
train_data_CB = OcclusionDataset(root_dir, classes = classes, 
                                transform = transforms.Compose([transforms.ToTensor()]))

batch_size = 4
num_workers = 0
valid_size = 0.2
# obtain training indices that will be used for validation
num_train = len(train_data_CB)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders (combine dataset and sampler)
train_loader = torch.utils.data.DataLoader(train_data_CB, batch_size=batch_size,
    sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data_CB, batch_size=batch_size, 
    sampler=valid_sampler, num_workers=num_workers)

correspondence_block = UNET.UNet(n_channels = 3, out_channels_id = 9, 
                            out_channels_uv = 256, bilinear=True)

correspondence_block.cuda()
    
criterion_id = nn.CrossEntropyLoss()
criterion_u = nn.CrossEntropyLoss()
criterion_v = nn.CrossEntropyLoss()

# specify optimizer
optimizer = optim.Adam(correspondence_block.parameters(), lr=3e-4,weight_decay=3e-5)

# Training Loop
# number of epochs to train the model
n_epochs = epochs
valid_loss_min = np.Inf # track change in validation loss

for epoch in range(1, n_epochs+1):
    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    
    ###################
    # train the model #
    ###################
    correspondence_block.train()
    for image, idmask,umask,vmask in train_loader:
        # move tensors to GPU 
        image, idmask,umask,vmask = image.cuda(), idmask.cuda(), umask.cuda(), vmask.cuda()
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        idmask_pred,umask_pred,vmask_pred = correspondence_block(image)       
        # calculate the batch loss
        loss_id = criterion_id(idmask_pred, idmask)
        loss_u = criterion_u(umask_pred, umask)
        loss_v = criterion_v(vmask_pred, vmask)
        loss = loss_id + loss_u + loss_v
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # update training loss
        train_loss += loss.item()


    ######################    
    # validate the model #
    ######################
    correspondence_block.eval()
    for image, idmask,umask,vmask in valid_loader:       
        # move tensors to GPU 
        image, idmask,umask,vmask = image.cuda(), idmask.cuda(), umask.cuda(), vmask.cuda()
        # forward pass: compute predicted outputs by passing inputs to the model
        idmask_pred,umask_pred,vmask_pred = correspondence_block(image)
        # calculate the batch loss
        loss_id = criterion_id(idmask_pred, idmask)
        loss_u = criterion_u(umask_pred, umask)
        loss_v = criterion_v(vmask_pred, vmask)
        loss = loss_id + loss_u + loss_v
        # update average validation loss 
        valid_loss += loss.item()
    
    # calculate average losses
    train_loss = train_loss/len(train_loader.sampler)
    valid_loss = valid_loss/len(valid_loader.sampler)
        
    # print training/validation statistics 
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, train_loss, valid_loss))
    
    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(correspondence_block.state_dict(), 'correspondence_block.pt')
        valid_loss_min = valid_loss

@ptrblck - Any help would be really appreciated sir in helping me understand what the error is.

Was the code working before, i.e. did you change anything in the code or your setup (CUDA, cudnn version)?
Could you rerun the code with CUDA_LAUNCH_BLOCKING=1 python script.py args and post the stack trace here again, please?
Also, which PyTorch, CUDA, cudnn versions are you using and how did you install these?

Yes it was working before but this error started showing up when CUDA 10.2 was released.
GPU capability couldn’t be used with 10.2 so I switched to 10.1 and installed using:
pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
Pytorch and torch versions:
torch==1.5.0+cu101 torchvision==0.6.0+cu101

Stack trace with CUDA_LAUNCH_BLOCKING=1:

THCudaCheck FAIL file=/pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu line=134 error=700 : an illegal memory access was encountered
Traceback (most recent call last):
  File "main.py", line 40, in <module>
    train_correspondence_block(root_dir)
  File "/home/jovyan/work/correspondence_block.py", line 77, in train_correspondence_block
    loss_id = criterion_id(idmask_pred, idmask)
  File "/opt/venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/venv/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 932, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/opt/venv/lib/python3.7/site-packages/torch/nn/functional.py", line 2317, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/opt/venv/lib/python3.7/site-packages/torch/nn/functional.py", line 2117, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: cuda runtime error (700) : an illegal memory access was encountered at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:134

Thanks for the update.
Could you check the target min and max value before passing it to nn.CrossEntropyLoss or nn.NLLLoss?
The expected values are in the range [0, nb_classes-1], while you might pass out of bounds values.

You could also run the code on the CPU, which should give you a better stack trace.

Thanks Patrick. You were correct.
Target value was out of bounds and running on CPU gave a much better stack trace.
This is a custom dataset that I created on my own. Look at the stack trace below, if you could tell me if the error is with the data or the model architecture that would be amazing:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-12-de1f45d3ba4c> in <module>
     22         idmask_pred,umask_pred,vmask_pred = correspondence_block(image)
     23         # calculate the batch loss
---> 24         loss_id = criterion_id(idmask_pred, idmask)
     25         loss_u = criterion_u(umask_pred, umask)
     26         loss_v = criterion_v(vmask_pred, vmask)

/opt/venv/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/venv/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    930     def forward(self, input, target):
    931         return F.cross_entropy(input, target, weight=self.weight,
--> 932                                ignore_index=self.ignore_index, reduction=self.reduction)
    933 
    934 

/opt/venv/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2315     if size_average is not None or reduce is not None:
   2316         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2317     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2318 
   2319 

/opt/venv/lib/python3.7/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2115         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2116     elif dim == 4:
-> 2117         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2118     else:
   2119         # dim == 3 or dim > 4

The stack trace and the posted code don’t show the valid ranges, but nn.CrossEntropyLoss expects a model output of the shape [batch_size, nb_classes, height, width] and a target of [batch_size, height, width] containing the class indices in the range [0, nb_classes-1] for a multi-class segmentation use case.

1 Like

Yes the input ids for the ground truth mask are for some reason out of range.
Thanks for helping me figure out what the error was :slight_smile: