Device side assert triggered while printing loss

I launched the code with CUDA_LAUNCH_BLOCKING=1 and the stack trace is as follows -

<ipython-input-1-0a71a9a41b52> in <module>
   1137         print(mask.shape)
   1138         with autocast():
-> 1139             loss = model((image, mask))
   1140             print(loss)
   1141 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-1-0a71a9a41b52> in forward(self, input)
    553             out = self.segaHead(out)
    554 
--> 555             focal = self.focalLoss(out, mask)
    556             dice = dice_loss(mask, out)
    557 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-1-0a71a9a41b52> in forward(self, input, target)
     64 
     65         # compute the negative likelyhood
---> 66         logpt = -F.cross_entropy(input, target)
     67         pt = torch.exp(logpt)
     68 

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2466     if size_average is not None or reduce is not None:
   2467         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2469 
   2470 

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:
-> 2264         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2265     elif dim == 4:
   2266         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: cuda runtime error (710) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1603729138878/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:115

and my focal loss definition is -


class FocalLoss2d(nn.Module):

    def __init__(self, gamma=0, weight=None, size_average=True):
        super(FocalLoss2d, self).__init__()

        self.gamma = gamma
        self.weight = weight
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.contiguous().view(input.size(0), input.size(1), -1)
            input = input.transpose(1,2)
            input = input.contiguous().view(-1, input.size(2)).squeeze()
        if target.dim()==4:
            target = target.contiguous().view(target.size(0), target.size(1), -1)
            target = target.transpose(1,2)
            target = target.contiguous().view(-1, target.size(2)).squeeze()
        elif target.dim()==3:
            target = target.view(-1)
        else:
            target = target.view(-1, 1)

        # compute the negative likelyhood
        logpt = -F.cross_entropy(input, target)
        pt = torch.exp(logpt)

        # compute the loss
        loss = -((1-pt)**self.gamma) * logpt

        # averaging (or not) loss
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

What exactly is throwing the error - ?
Added this as a comment so as to not make original post too long