Loss function for segmentation models

In that case nn.NLLLoss should work fine.
Sorry for missing the shape information. Your mask shouldn't have the channel dimension; it should store the class index of each pixel instead, i.e. have the shape [batch_size, h, w] with dtype torch.long.
Here is a small code example:

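import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 1
nb_classes = 3
h, w = 24, 24
output = torch.randn(batch_size, nb_classes, h, w)  # model output (logits): [N, C, H, W]
target = torch.empty(batch_size, h, w, dtype=torch.long).random_(nb_classes)  # class indices: [N, H, W]
criterion = nn.NLLLoss()
loss = criterion(F.log_softmax(output, dim=1), target)  # NLLLoss expects log-probabilities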
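If your mask is currently one-hot encoded with a channel dimension, here is a minimal sketch of converting it to the class-index format. It assumes the mask has shape [batch_size, nb_classes, h, w] with exactly one active class per pixel; the names `one_hot_mask` and `indices` are hypothetical. Taking `argmax` over the channel dimension recovers the index. The sketch also checks that `nn.NLLLoss` on `log_softmax` outputs gives the same value as `nn.CrossEntropyLoss` on the raw logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, nb_classes, h, w = 2, 3, 24, 24

# Hypothetical one-hot mask with a channel dimension: [N, C, H, W]
indices = torch.randint(0, nb_classes, (batch_size, h, w))
one_hot_mask = F.one_hot(indices, num_classes=nb_classes).permute(0, 3, 1, 2).float()

# Drop the channel dimension by taking the index of the active class: [N, H, W]
# argmax returns int64 (torch.long), which is the dtype NLLLoss expects for targets
target = one_hot_mask.argmax(dim=1)

# Raw model output (logits): [N, C, H, W]
logits = torch.randn(batch_size, nb_classes, h, w)

# NLLLoss applied to log-probabilities ...
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
# ... gives the same value as CrossEntropyLoss applied directly to the logits
ce = nn.CrossEntropyLoss()(logits, target)

print(target.shape, target.dtype)
print(torch.allclose(nll, ce))
```

Since `nn.CrossEntropyLoss` combines `log_softmax` and `NLLLoss`, passing the raw logits to it is usually the simpler option; `nn.NLLLoss` fits when the model already ends with an `nn.LogSoftmax(dim=1)` layer.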