CUDA error: device-side assert triggered

I'm stuck on a CUDA error and I don't know why it suddenly started happening.
Please help me fix this problem.

device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)
print(device)
[out] cuda
# model
class unet(nn.Module):
    def __init__(self):
        super(unet, self).__init__()
        self.encoder1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        
        self.conv_mid = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),   
                                     nn.Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder4 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder3 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder2 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder1 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.upconv4 = nn.ConvTranspose2d(1024, 1024, kernel_size=(2, 2), stride=(2, 2))
        
        self.upconv3 = nn.ConvTranspose2d(512, 512, kernel_size=(2, 2), stride=(2, 2))
        
        self.upconv2 = nn.ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
        
        self.upconv1 = nn.ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2))
        
        self.conv1x1_out = nn.Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        out = self.encoder1(x)
        out = self.pool1(out)
        out = self.encoder2(out)
        out = self.pool2(out)
        out = self.encoder3(out)
        out = self.pool3(out)
        out = self.encoder4(out)
        out = self.pool4(out)
        out = self.conv_mid(out)
        out = self.upconv4(out)
        out = self.decoder4(out)
        out = self.upconv3(out)
        out = self.decoder3(out)
        out = self.upconv2(out)
        out = self.decoder2(out)
        out = self.upconv1(out)
        out = self.decoder1(out)
        out = self.conv1x1_out(out)
        return out
    
model = unet().to(device)
[out] 
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-43-810580fbbd02> in <module>()
    104         return out
    105 
--> 106 model = unet().to(device)

4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in convert(t)
    421 
    422         def convert(t):
--> 423             return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
    424 
    425         return self._apply(convert)

RuntimeError: CUDA error: device-side assert triggered

If you set CUDA_LAUNCH_BLOCKING=1, you should get a more precise traceback that points at the operation that actually failed.
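
For example, a minimal sketch (the variable has to be set before the first CUDA call, so put it at the very top of the script and restart the runtime first):

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # forces synchronous kernel launches

import torch  # import torch (and do any CUDA work) only after setting the variable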

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

I just used the code above and still can't solve it.
Do I need to change my loss function?

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-10-885ac444824b> in <module>()
     22 
     23         pred_mask = model(mri)
---> 24         loss = loss_func(pred_mask, true_mask)
     25 
     26         optimizer.zero_grad()

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
    914     def forward(self, input, target):
    915         return F.cross_entropy(input, target, weight=self.weight,
--> 916                                ignore_index=self.ignore_index, reduction=self.reduction)
    917 
    918 

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2019     if size_average is not None or reduce is not None:
   2020         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2021     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2022 
   2023 

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1838         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1839     elif dim == 4:
-> 1840         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1841     else:
   1842         # dim == 3 or dim > 4

RuntimeError: cuda runtime error (710) : device-side assert triggered at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:127

Does your code work when device = 'cpu'? On the CPU, the same failure usually surfaces as a regular Python exception with an exact error message, which is much easier to debug.
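
Running a single batch through the model and the criterion on the CPU is usually enough to reproduce it. A rough sketch, assuming your DataLoader is called train_loader:

model_cpu = unet()                          # keep everything on the CPU
loss_func = nn.CrossEntropyLoss()
mri, true_mask = next(iter(train_loader))   # one batch is enough to reproduce
loss = loss_func(model_cpu(mri), true_mask)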

No, it doesn't.
Does this error have something to do with the CUDA error?

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-23-f939b650666d> in <module>()
     22 
     23         pred_mask = model(mri)
---> 24         loss = loss_func(pred_mask, true_mask)
     25 
     26         optimizer.zero_grad()

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/usr/local/lib/python3.5/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
    914     def forward(self, input, target):
    915         return F.cross_entropy(input, target, weight=self.weight,
--> 916                                ignore_index=self.ignore_index, reduction=self.reduction)
    917 
    918 

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2019     if size_average is not None or reduce is not None:
   2020         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2021     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2022 
   2023 

/usr/local/lib/python3.5/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1838         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1839     elif dim == 4:
-> 1840         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1841     else:
   1842         # dim == 3 or dim > 4

IndexError: Target 255 is out of bounds.

Edit:

I found that when I use a single image and its mask, everything works fine. I think the error occurs once the whole dataset goes through the loss function.
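
To check this, I can scan every mask in the dataset for out-of-range values (a rough sketch; it assumes my Dataset returns (image, mask) pairs):

labels = set()
for _, mask in dataset:
    labels |= set(torch.unique(mask).tolist())
print(labels)  # anything outside {0, 1} would explain the assert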

In fact, the code you posted gives me no problem initializing the model and moving it to the GPU. I found your error discussed in this post. Note that once a device-side assert fires, the CUDA context is left in a corrupted state, so every subsequent CUDA call in the same process (even moving a freshly created model to the GPU) raises the same error until you restart the runtime.

I have no idea why moving the model to the GPU is not working for me. I don't think the model is wrong: it runs fine on the CPU (there is still an error caused by the loss function, but that is not about the model architecture), yet the error occurs as soon as I move it to the GPU. I tried it on Google Colab and the result was the same.

I read the post you linked, but I don't think my label and output sizes are wrong, since my task is image segmentation. I only need two labels: background and the target.

The error points to the criterion, so you should check that the model output and target have valid shapes and values.
I.e. for a vanilla multi-class classification, the target should have the shape [batch_size] and contain values in the range [0, nb_classes-1]; for 2D segmentation (nll_loss2d) the target should have the shape [batch_size, height, width] with the same constraint on the values.

This error occurs when the target contains class indices that are out of range:

>>> import torch
>>> loss_fn = torch.nn.CrossEntropyLoss()
>>> input = torch.randn(3, 2, 24, 24)  # [batch, 2 classes, H, W] logits
>>> target = torch.empty(3, 24, 24, dtype=torch.long).random_(2) * 255  # values are 0 or 255; only {0, 1} are valid
>>> loss = loss_fn(input, target)
IndexError: Target 255 is out of bounds.
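
A minimal fix sketch for this case, assuming your masks encode the foreground as 255: remap the mask values to valid class indices before calling the criterion.

>>> target[target == 255] = 1      # map foreground 255 -> class index 1
>>> loss = loss_fn(input, target)  # runs without the assert now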