Regarding the use of ignore index in cross entropy loss

I am trying to understand how ignore_index works with the cross entropy loss. I am working on a segmentation problem, and if all the segmentation values are -100, I don't want the loss to propagate, since the segmentation does not exist for that specific case.
So I tested the following code:

import torch.nn as nn
import torch

target_tensor = torch.ones(960,960)*-100
target_tensor.requires_grad = True
input_tensor = torch.randn(960, 960, requires_grad=True)

loss = nn.CrossEntropyLoss(ignore_index=-100,reduction='mean')
output = loss(input_tensor, target_tensor )
output.backward()

When I inspect output, it shows -707187.4375. According to the docs, the input gradient should be 0 since every target value is -100, but it contains non-zero values:

print(input_tensor.grad)
tensor([[-0.0546,  0.1607, -0.0571,  ...,  0.0920,  0.0534, -0.1826],
        [ 0.1276,  0.1603,  0.1651,  ...,  0.0619,  0.1603, -0.2230],
        [ 0.1149, -0.0944,  0.1460,  ...,  0.1027,  0.1783, -0.0054],
        ...,
        [ 0.0059,  0.1577,  0.1724,  ...,  0.1459,  0.1738,  0.1727],
        [ 0.1814,  0.1275, -0.2737,  ..., -0.2048, -0.0896,  0.1874],
        [ 0.0066,  0.0401, -0.6887,  ...,  0.0292,  0.1217,  0.1145]])

Can anyone tell me what's wrong, or show me an example where ignore_index is used and the loss is not propagated?

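Your example has a shape mismatch between the target and the input. nn.CrossEntropyLoss applies softmax over the channel/class dimension (dim 1) of the input, so for segmentation it expects an input of shape [B, C, H, W] and a target of shape [B, H, W] that contains class indices (dtype long).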
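As a side note on where the negative loss in the question likely comes from: as far as I understand, a floating-point target with the same shape as the input is interpreted as per-class probabilities, and ignore_index only applies to class-index targets. The sketch below uses small stand-in sizes (N and C are my own choice, not the 960 x 960 from the question) and compares the result with a manual computation of the probability-target formula.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, C = 4, 3  # small stand-ins for the 960 x 960 tensors in the question
logits = torch.randn(N, C, requires_grad=True)
float_target = torch.full((N, C), -100.0)  # float, same shape as the input

# A float target with the input's shape is treated as class probabilities,
# so -100 acts as a weight on log-softmax rather than as an ignore marker.
loss = F.cross_entropy(logits, float_target, ignore_index=-100)

# Manual version of the probability-target formula, averaged over the batch.
manual = -(float_target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

loss.backward()
print(loss.item(), manual.item())      # the two values should agree
print(logits.grad.abs().sum().item())  # non-zero: nothing was ignored
```

Converting the target to long and giving it the shape [B, H, W] is what switches the loss to the class-index path where ignore_index is honoured.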

import torch
import torch.nn as nn

B, C, H, W = 1, 5, 6, 6

target_tensor = torch.ones((B, H, W))
target_tensor[:, :3, :] = -100
input_tensor = torch.randn((B, C, H, W), requires_grad=True)

loss_f = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
loss = loss_f(input_tensor, target_tensor.long())

print(loss)
loss.backward()

print(input_tensor.grad)
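To check that the ignored pixels really contribute no gradient, here is a minimal sketch using the same shapes as above. The variable names (ignored, grad_ignored, grad_kept) are only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, C, H, W = 1, 5, 6, 6
target = torch.ones((B, H, W), dtype=torch.long)
target[:, :3, :] = -100  # the top three rows are ignored
logits = torch.randn((B, C, H, W), requires_grad=True)

loss_f = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
loss = loss_f(logits, target)
loss.backward()

# Broadcast the [B, H, W] ignore mask over the class dimension.
ignored = (target == -100).unsqueeze(1).expand_as(logits)
grad_ignored = logits.grad[ignored]
grad_kept = logits.grad[~ignored]

print(grad_ignored.abs().max().item())  # gradient at ignored pixels
print(grad_kept.abs().max().item())     # gradient at the remaining pixels
```

The first value should be zero and the second non-zero, which is the behaviour the question was looking for.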

Thanks, I followed your method, but since I am working with medical images I have only 1 channel, and the segmentation size is 960*960. This is what I tested with those parameters:

import torch
import torch.nn as nn

B, C, H, W = 16, 1, 960, 960
target_tensor = torch.ones((B, H, W))
target_tensor[:, 3:, :] = -100
input_tensor = torch.randn((B, C, H, W), requires_grad=True)

loss_f = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
loss = loss_f(input_tensor, target_tensor.long())

print(loss)
loss.backward()

print(input_tensor.grad)

I get the following error:


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Input In [191], in <cell line: 10>()
      7 input_tensor = torch.randn((B, C, H, W), requires_grad=True)
      9 loss_f = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
---> 10 loss = loss_f(input_tensor, target_tensor.long())
     12 print(loss)
     13 loss.backward()

File ~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
   1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174     return F.cross_entropy(input, target, weight=self.weight,
   1175                            ignore_index=self.ignore_index, reduction=self.reduction,
   1176                            label_smoothing=self.label_smoothing)

File ~/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py:3026, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3024 if size_average is not None or reduce is not None:
   3025     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

IndexError: Target 1 is out of bounds.

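In the case of a single class, cross entropy with softmax activation is not appropriate, as softmax normalizes across multiple classes. With C = 1 the only valid class index is 0, which is why the target value 1 is reported as out of bounds.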
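A small sketch of why a single channel gives softmax nothing to work with: the softmax over one channel is always 1, so with the only valid target (0) the loss and its gradient vanish regardless of the logits. The shapes here are arbitrary small values chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn((2, 1, 4, 4), requires_grad=True)  # C = 1
target = torch.zeros((2, 4, 4), dtype=torch.long)       # 0 is the only valid index

probs = F.softmax(logits, dim=1)  # normalizes over a single channel
loss = F.cross_entropy(logits, target)
loss.backward()

print(probs.min().item(), probs.max().item())  # every probability is 1
print(loss.item())                             # loss carries no information
print(logits.grad.abs().max().item())          # and produces no gradient
```

So even if the targets were remapped to 0, a one-channel model could not learn anything from this loss.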

nn.BCEWithLogitsLoss uses sigmoid activation and will work with a single class. It does not support an ignore index, but I’ve included some code to set the loss to 0 at the locations of your choice with the mask variable.

import torch
import torch.nn as nn


B, C, H, W = 1, 1, 6, 6

target_tensor = torch.ones((B, C, H, W))
mask = torch.ones((B, C, H, W))
mask[:, :, :3, :] = 0
input_tensor = torch.randn((B, C, H, W), requires_grad=True)

loss_f = nn.BCEWithLogitsLoss(reduction='none')

loss = (loss_f(input_tensor, target_tensor) * mask).mean()
print(loss)

loss.backward()

print(input_tensor.grad)
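One thing to be aware of: .mean() above divides by every element, including the masked ones, so the loss shrinks as more pixels are masked. If you would rather average over the unmasked pixels only, a possible variant is sketched below. masked_bce_with_logits is a hypothetical helper name, and the clamp is my own choice to avoid dividing by zero when everything is masked.

```python
import torch
import torch.nn.functional as F

def masked_bce_with_logits(logits, target, mask):
    """Hypothetical helper: BCE-with-logits averaged over mask == 1 only."""
    per_pixel = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
    valid = mask.sum().clamp(min=1.0)  # avoid 0 / 0 for a fully masked batch
    return (per_pixel * mask).sum() / valid

torch.manual_seed(0)
B, C, H, W = 1, 1, 6, 6
logits = torch.randn((B, C, H, W), requires_grad=True)
target = torch.ones((B, C, H, W))

# Half-masked case: should match the plain mean over the unmasked rows.
half_mask = torch.ones((B, C, H, W))
half_mask[:, :, :3, :] = 0
loss_half = masked_bce_with_logits(logits, target, half_mask).detach()
reference = F.binary_cross_entropy_with_logits(logits[:, :, 3:], target[:, :, 3:]).detach()

# Fully masked case: the loss is 0 and nothing is propagated.
empty_mask = torch.zeros((B, C, H, W))
loss_empty = masked_bce_with_logits(logits, target, empty_mask)
loss_empty.backward()

print(loss_half.item(), reference.item())
print(loss_empty.item(), logits.grad.abs().max().item())
```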

Thanks, this makes sense. I figured I had a mistake on my end. I could use nn.BCEWithLogitsLoss like you suggested, or I could treat the binary segmentation as a multi-class segmentation use case with 2 classes.

For this approach the model would return output logits in the shape [batch_size, 2, height, width], the target would have the shape [batch_size, height, width] and contain the class indices [0, 1].
Then I would be able to use cross entropy loss.
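A rough sketch of that two-class setup, with one sample whose target is entirely -100. The sizes are arbitrary. I am using reduction='sum' and dividing by the number of valid pixels myself, on the assumption that reduction='mean' divides by the count of non-ignored targets and could return nan for a batch in which everything is ignored.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, H, W = 2, 4, 4
logits = torch.randn((B, 2, H, W), requires_grad=True)  # two classes
target = torch.randint(0, 2, (B, H, W))                 # class indices, dtype long
target[0] = -100  # the first sample has no segmentation

# Sum over the non-ignored pixels, then normalize by their count.
summed = F.cross_entropy(logits, target, ignore_index=-100, reduction='sum')
num_valid = (target != -100).sum().clamp(min=1)  # guard against an all-ignored batch
loss = summed / num_valid
loss.backward()

grad_first = logits.grad[0]   # sample with every pixel ignored
grad_second = logits.grad[1]  # sample with a real segmentation

print(loss.item())
print(grad_first.abs().max().item())   # nothing propagated
print(grad_second.abs().max().item())  # gradient flows as usual
```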
Thanks for the quick response.