CrossEntropy loss function error - target 3 is out of bounds

Hey there!

I’m running a semantic segmentation model (UNet) on input images of shape (3, 256, 256) and masks of shape (256, 256) with pixel values of 0, 1, 2, and 3 (3 classes total, with pixel value 0 being the background).

When retrieving the output from the model, I have:

- output from model: (b_s, 3, 256, 256)
- predicted mask: (b_s, 256, 256)
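If the predicted mask is obtained by taking the argmax over the class dimension (an assumption; the post doesn't say how it is computed), the two shapes relate as in this minimal sketch. Random logits stand in for the UNet output:

```python
import torch

b_s, n_classes, H, W = 2, 3, 256, 256
# Random logits standing in for the UNet output: (b_s, n_classes, H, W)
logits = torch.randn(b_s, n_classes, H, W)

# dim=1 is the class dimension; argmax over it gives one label per pixel
pred_mask = logits.argmax(dim=1)

print(logits.shape)     # torch.Size([2, 3, 256, 256])
print(pred_mask.shape)  # torch.Size([2, 256, 256])
# With only 3 output channels, the predicted labels can only be 0, 1, or 2
print(int(pred_mask.max()) <= n_classes - 1)  # True
```

Note that a 3-channel output can never predict label 3, which already hints at the mismatch with masks containing values 0 through 3.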

I use a cross-entropy loss function, but receive the following error:

```
   1838         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1839     elif dim == 4:
-> 1840         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1841     else:
   1842         # dim == 3 or dim > 4

IndexError: Target 3 is out of bounds.
```

This is the model I am using: source

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # one output channel (logit map) per class
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x.float())
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
```

Initializing the model:

```python
import torch.nn as nn
import torch.optim as optim

model = UNet(n_channels=3, n_classes=3, bilinear=False)

optimizer = optim.RMSprop(model.parameters(), lr=0.001, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

criterion = nn.CrossEntropyLoss()
```

The interesting thing is that when I initialize my model with n_classes = 4, 5, and so on, the training loop runs just fine, but the outputs are terrible. It’s only when I set n_classes = 3 that it all breaks. Any help would be greatly appreciated!

Hi Ben!

The short story is that you have four classes, not three – “background”
counts as one of your classes.

Your model output has shape [b_s, n_classes = 3, 256, 256]. So,
yes, your model is predicting (only) three classes.

CrossEntropyLoss sees that its input (your model output) has
n_classes = 3, so it will require that your target only has values
for three classes. That is, your target values must be integer class
labels running from [0, n_classes - 1], i.e., be in (0, 1, 2).

But your target has values in (0, 1, 2, 3). When CrossEntropyLoss
hits the “fourth” class value of 3, it throws the error you see.
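A minimal CPU reproduction of this mismatch, using random tensors in place of the real model output and masks (the small 8×8 size is just for illustration):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(2, 3, 8, 8)          # model output with only 3 channels
target = torch.randint(0, 3, (2, 8, 8))   # labels 0..2 so far
target[0, 0, 0] = 3                       # ...plus one pixel of the "fourth" class

raised = False
try:
    criterion(logits, target)
except (IndexError, RuntimeError) as err:
    # On CPU this surfaces as IndexError ("Target 3 is out of bounds.");
    # on a GPU the same mistake usually shows up as a device-side assert.
    raised = True
    print(type(err).__name__, err)
```

A single pixel labelled 3 is enough to trigger the error, because the loss has no output channel to look up for that label.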

It sounds like your problem really does have four classes, three
foreground classes plus one background class. So leave your target
as it is, but run your model with n_classes = 4 so that its output
has shape [b_s, 4, 256, 256].
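Here is a sketch of the fixed setup, with a small sanity check that fails early and readably if the masks ever contain a label the model has no channel for. `check_targets` is a hypothetical helper, not part of PyTorch, and it assumes you are not using `ignore_index` labels:

```python
import torch
import torch.nn as nn

def check_targets(logits: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical helper: raise a readable error if `target` holds labels
    the model output has no channel for. Assumes no ignore_index labels."""
    n_classes = logits.shape[1]
    lo, hi = int(target.min()), int(target.max())
    if lo < 0 or hi >= n_classes:
        raise ValueError(
            f"target labels span [{lo}, {hi}], but the model outputs "
            f"{n_classes} channels (valid labels: 0..{n_classes - 1})"
        )

criterion = nn.CrossEntropyLoss()

# Background + 3 foreground classes -> 4 output channels, as with
# UNet(n_channels=3, n_classes=4, ...). Random logits stand in for the model.
logits = torch.randn(2, 4, 8, 8)
target = torch.randint(0, 4, (2, 8, 8))   # labels 0, 1, 2, 3

check_targets(logits, target)   # passes: max label 3 < 4 channels
loss = criterion(logits, target)
print(loss.item())  # a finite scalar; its value depends on the random inputs
```

Running the same check with a 3-channel output raises a `ValueError` that names the offending label range, which is easier to act on than the traceback above.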

Note, you might want to check whether your dataset is unbalanced.
It is not unusual to have many more background pixels than foreground.
(I’m not saying you do, but you might.) If so, you might consider using
CrossEntropyLoss’s weight argument to reweight the classes in your
loss calculation.
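A minimal sketch of passing per-class weights to `CrossEntropyLoss`. The weight values are placeholders; the tensor needs one float entry per class, in class-index order. The last lines show what the default `reduction="mean"` does with them: it computes a weighted mean over pixels rather than a plain average.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# One weight per class, in class-index order:
# [background, class 1, class 2, class 3]. These values are placeholders.
class_weights = torch.tensor([0.3, 1.0, 1.0, 1.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)           # (batch, n_classes, H, W)
target = torch.randint(0, 4, (2, 8, 8))    # class indices 0..3
loss = criterion(logits, target)

# With reduction="mean", the result is sum(w[y] * loss_per_pixel) / sum(w[y])
per_pixel = F.cross_entropy(logits, target, reduction="none")
pixel_weights = class_weights[target]
manual = (per_pixel * pixel_weights).sum() / pixel_weights.sum()
print(torch.allclose(loss, manual))  # True
```

If you train on a GPU, the weight tensor must live on the same device as the model output (for example, `class_weights.to(device)`).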

Best.

K. Frank


Thank you very much! This was very helpful.

Given that I have many more background pixels than pixels of the other classes, would I initialize the cross-entropy loss as follows to reduce the weight of the background?

```python
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.3, 1.0, 1.0, 1.0]))
```

Thanks again! Life saver.

Hi Ben!

This could be reasonable. You are weighting each of the foreground
pixels equally, and weighting the background pixels about one third as
much as any one foreground class.

So this would make sense if each of your three foreground classes
appeared about equally often, and each of your foreground classes
individually appeared about one third as often as your background
class.

The typical advice is to weight your classes in about inverse proportion
to their frequency of occurrence.
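A sketch of computing such inverse-frequency weights from the masks themselves. `inverse_frequency_weights` is a hypothetical helper; it rescales the weights so they average to 1, which is one convenient choice, not a requirement. The toy mask has the frequencies described above (each foreground class appearing one third as often as the background), so the background ends up with one third of the foreground weight, the same ratio as `[0.3, 1, 1, 1]`:

```python
import torch

def inverse_frequency_weights(masks: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical helper: per-class weights inversely proportional to how
    many pixels carry each label. `masks` holds non-negative integer class
    indices (any shape, e.g. (N, H, W)). The result is rescaled so the
    weights average to 1."""
    counts = torch.bincount(masks.flatten(), minlength=n_classes).float()
    counts = counts.clamp(min=1)  # avoid division by zero for absent classes
    weights = 1.0 / counts
    return weights * n_classes / weights.sum()

# Toy mask: background is half the pixels, each foreground class one sixth.
masks = torch.tensor([[0, 0, 0, 0, 0, 0],
                      [1, 1, 2, 2, 3, 3]])
weights = inverse_frequency_weights(masks, n_classes=4)
print(weights)  # approximately tensor([0.4000, 1.2000, 1.2000, 1.2000])
```

In practice you would accumulate the pixel counts over all of your training masks (for example, by summing `torch.bincount` results batch by batch) rather than using a single toy tensor.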

Having said that, a three-to-one difference in frequency of occurrence
doesn’t strike me as being highly unbalanced. In such a case I might
expect using class weights to help at the margins, but I doubt doing so
would be a game changer.

Best.

K. Frank