Loss function for multi-class semantic segmentation

I’m working on a semantic segmentation problem where each pixel may belong to one or more classes. However, I cannot find a suitable loss function to compute binary cross-entropy loss over each pixel in the image. BCELoss seems to require a single scalar target value per sample, while CrossEntropyLoss allows only one class per pixel.
Is there any built-in loss for this problem (similar to binary_crossentropy in Keras), or do I need to write a new one?
Thank you very much.

If you are dealing with a multi-label classification, nn.BCELoss (or nn.BCEWithLogitsLoss, which takes raw logits, as used below) should work fine, if you pass the target as a “multi-hot” encoded tensor.

import torch
import torch.nn as nn

# multi-hot target: classes 1 and 3 are active for this single sample
target = torch.tensor([[0, 1, 0, 1, 0, 0]], dtype=torch.float32)
output = torch.randn(1, 6, requires_grad=True)  # raw logits for 6 classes

criterion = nn.BCEWithLogitsLoss()
loss = criterion(output, target)
loss.backward()

In this small example, I just passed a single data sample with a target in which two classes are active.
For a semantic segmentation use case, each pixel should contain the corresponding label (0 or 1) for every class in the channel dimension.
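
For example, here is a minimal sketch of that layout (the shapes are arbitrary and chosen purely for illustration):

import torch
import torch.nn as nn

B, C, H, W = 1, 3, 4, 4  # hypothetical batch size, classes, height, width

# multi-hot target of shape [B, C, H, W]: each pixel holds a 0/1 label
# per class, so a single pixel can be positive for several classes
target = torch.randint(0, 2, (B, C, H, W)).float()
output = torch.randn(B, C, H, W, requires_grad=True)  # raw logits

criterion = nn.BCEWithLogitsLoss()
loss = criterion(output, target)
loss.backward()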

My target and output are of shape [B, C, H, W], and the positive weights tensor is of shape [C].

import torch

target = torch.ones([1, 2, 5, 5], dtype=torch.float32)  # [B, C, H, W]
output = torch.randn(1, 2, 5, 5, requires_grad=True)    # raw logits
positive_weights = torch.tensor([2.0, 2.0])             # one weight per class, shape [C]
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)
loss = criterion(output, target)
loss.backward()

The error is:

RuntimeError: The size of tensor a (2) must match the size of tensor b (5) at non-singleton dimension 3

How can I fix this?

It seems that the problem is caused by the pos_weight=positive_weights argument.
From the docs:

  • pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.

But why does it treat 5 as the number of classes?

As @MariosOreo said, it seems the pos_weight argument throws this error.
A quick fix might be to permute and reshape the output and target such that the two classes end up in dim1:

# permute moves the class dim last and makes the tensors non-contiguous,
# so reshape (rather than view) is needed to flatten them to [N, C]
loss = criterion(output.permute(0, 2, 3, 1).reshape(-1, 2), target.permute(0, 2, 3, 1).reshape(-1, 2))

or to expand the pos_weight tensor manually:

# add singleton H and W dims, then expand to the spatial size [C, H, W]
positive_weights = positive_weights[:, None, None].expand(-1, 5, 5)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)

However, this seems like unclear behavior to me, so feel free to open a GitHub issue to further discuss this use case.

Actually, it’s a bug. When pos_weight was added to BCEWithLogitsLoss, it wasn’t intended to be used with per-pixel classifiers, so broadcasting doesn’t work well in this case.

You can reshape pos_weight by adding dummy dimensions for H and W to work around this bug.

positive_weights = torch.tensor([2.0, 2.0]).reshape(1, 2, 1, 1)  # [1, C, 1, 1] broadcasts over B, H and W
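
For completeness, here is a minimal sketch of the full workaround applied to the same shapes as in the earlier post:

import torch

target = torch.ones([1, 2, 5, 5], dtype=torch.float32)
output = torch.randn(1, 2, 5, 5, requires_grad=True)

# [1, C, 1, 1] broadcasts cleanly against the [B, C, H, W] input
positive_weights = torch.tensor([2.0, 2.0]).reshape(1, 2, 1, 1)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)
loss = criterion(output, target)  # no shape error
loss.backward()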

Unfortunately, my implementation of pos_weight was rewritten from Python to C++, and I’m not sure if I can fix it right away.

Hey, first of all, if each pixel may belong to one or more classes, you are dealing with a multi-label segmentation task rather than a multi-class one: in multi-class segmentation, each pixel belongs to exactly one class.
Have you considered the segmentation-models-pytorch (smp) library? It also ships a package of losses, including DiceLoss, FocalLoss and TverskyLoss, and lets you choose the loss mode: 'binary', 'multilabel' or 'multiclass': SMP’s defined losses
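
As an example, here is a minimal sketch using smp’s DiceLoss in multi-label mode (assuming the library is installed; the shapes are arbitrary):

import torch
import segmentation_models_pytorch as smp

# logits and multi-hot target, both of shape [B, C, H, W]
output = torch.randn(1, 2, 5, 5, requires_grad=True)
target = torch.randint(0, 2, (1, 2, 5, 5)).float()

# mode='multilabel' treats each channel as an independent binary mask
criterion = smp.losses.DiceLoss(mode='multilabel', from_logits=True)
loss = criterion(output, target)
loss.backward()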