The cost function for semantic segmentation?


I’m trying to understand the process of semantic segmentation and I’m having trouble with the loss function. For simple classification networks the output is usually a 1-dimensional tensor whose size equals the number of classes, but for semantic segmentation the target is also an image.

I have an input image of shape Inputs: torch.Size([1, 3, 224, 224]), which produces an output of shape Output: torch.Size([1, 32, 224, 224]). The target, on the other hand, has shape Targets: torch.Size([1, 1, 360, 480]).

There is a size mismatch, does that matter?

I tried the loss function criterion = torch.nn.BCEWithLogitsLoss(), but with this PyCharm just stalls and crashes.

The example I am following used CrossEntropyLoss2D() as shown here:

but when I use that I get an error with a warning that NLLLoss2d has been deprecated. Furthermore, there is no 2D loss function listed here in the documentation:

I tried resizing the target to make it the same size as the output, but that didn’t work either.

I would like to read more about loss functions for semantic segmentation but couldn’t find much help. Why am I having trouble with the loss functions? Is it because of the size mismatch? Since both are images, do I need to write my own class to handle this?

Many thanks for any help/guidance.


Based on the output shape it looks like you have 32 different classes.
Your target shape, i.e. the segmentation mask, should have the shape [batch_size, 224, 224], and should contain the class indices as its values.
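A minimal sketch of that conversion (the mask shape is taken from your post; the class values are dummy data): squeeze out the channel dimension and make sure the dtype is long, since CrossEntropyLoss expects integer class indices rather than floats:

```python
import torch

# dummy mask with the question's target shape [1, 1, 360, 480]
mask = torch.randint(0, 32, (1, 1, 360, 480))

# drop the channel dimension and cast to long:
# CrossEntropyLoss expects class indices of dtype torch.long
target = mask.squeeze(1).long()
print(target.shape)  # torch.Size([1, 360, 480])
```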

The spatial size mismatch between the target mask and the model output does matter, because the loss is calculated pixel-wise, i.e. each pixel prediction must correspond to a pixel target class at the same location.
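One way to fix the mismatch (a sketch, not the only option) is to resize the mask to the output’s spatial size with nearest-neighbor interpolation, which keeps the values valid class indices; bilinear interpolation would blend neighboring labels into meaningless in-between values. F.interpolate expects a floating-point tensor, so cast back and forth:

```python
import torch
import torch.nn.functional as F

mask = torch.randint(0, 32, (1, 1, 360, 480))  # dummy target mask

# nearest-neighbor keeps the labels intact; interpolate needs a float tensor
resized = F.interpolate(mask.float(), size=(224, 224), mode='nearest')
target = resized.squeeze(1).long()  # shape [1, 224, 224], class indices
print(target.shape)  # torch.Size([1, 224, 224])
```

Alternatively, you can resize the mask on the CPU during data loading (e.g. with PIL using NEAREST resampling) before converting it to a tensor.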

You don’t have to use the *2d loss functions, as the vanilla loss functions can now take multi-dimensional tensors.

Here is a small dummy example for a segmentation use case:

import torch
import torch.nn as nn

batch_size = 1
c, h, w = 3, 10, 10
nb_classes = 5

x = torch.randn(batch_size, c, h, w)
target = torch.empty(batch_size, h, w, dtype=torch.long).random_(nb_classes)

model = nn.Sequential(
    nn.Conv2d(c, 6, 3, 1, 1),
    nn.Conv2d(6, nb_classes, 3, 1, 1)
)

criterion = nn.CrossEntropyLoss()

output = model(x)
loss = criterion(output, target)
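To connect this back to the deprecated NLLLoss2d: CrossEntropyLoss on a 4D input is just a log-softmax over the class dimension followed by a pixel-wise negative log-likelihood, which you can verify by hand (a sketch reusing the same dummy shapes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, nb_classes, h, w = 1, 5, 10, 10
output = torch.randn(batch_size, nb_classes, h, w)
target = torch.randint(0, nb_classes, (batch_size, h, w))

loss = nn.CrossEntropyLoss()(output, target)

# manual version: log-softmax over the class dim (dim=1), then the
# negative log-probability of each pixel's target class, averaged
log_probs = F.log_softmax(output, dim=1)
manual = F.nll_loss(log_probs, target)
print(torch.allclose(loss, manual))  # True
```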

That is very helpful. Many thanks 🙂