I’m working on a semantic segmentation problem where each pixel may belong to one or more classes. However, I cannot find a suitable loss function to compute binary cross-entropy loss over each pixel in the image. As far as I can tell, BCELoss requires a single scalar value as the target, while CrossEntropyLoss allows only one class per pixel.
Is there any built-in loss for this problem (similar to binary_crossentropy in Keras), or do I need to write a new one?
Thank you very much.
Loss function for multi-class semantic segmentation
If you are dealing with multi-label classification, nn.BCELoss (or nn.BCEWithLogitsLoss, which takes raw logits and is more numerically stable) should work fine if you pass the target as a “multi-hot” encoded tensor.
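import torch
import torch.nn as nn
# multi-hot target: classes 1 and 3 are active for this sample
target = torch.tensor([[0,1,0,1,0,0]], dtype=torch.float32)
# raw logits (no sigmoid), one per class
output = torch.randn(1, 6, requires_grad=True)
criterion = nn.BCEWithLogitsLoss()
loss = criterion(output, target)
loss.backward()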
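To make the multi-label behavior concrete, here is a minimal sketch (the seed and variable names are my own) checking that nn.BCEWithLogitsLoss is simply an independent binary cross-entropy per class, averaged over all elements, and that nn.BCELoss applied to sigmoid outputs gives the same number.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Multi-hot target: classes 1 and 3 are active for this sample.
target = torch.tensor([[0, 1, 0, 1, 0, 0]], dtype=torch.float32)
logits = torch.randn(1, 6)

# Built-in loss: applies the sigmoid internally and averages over all elements.
builtin = nn.BCEWithLogitsLoss()(logits, target)

# The same loss written out by hand: one binary cross-entropy per class.
probs = torch.sigmoid(logits)
manual = -(target * torch.log(probs) + (1 - target) * torch.log(1 - probs)).mean()

# nn.BCELoss on sigmoid outputs matches too (but is less numerically stable).
bce_on_probs = nn.BCELoss()(probs, target)

print(builtin.item(), manual.item(), bce_on_probs.item())
```

Nothing in the formula couples the classes, which is why several of them can be 1 at once.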
In this small example, I just passed a single data sample with a target for two active classes.
For semantic segmentation, the target should have the same shape as the output, [B, C, H, W], with each pixel holding the corresponding label (0 or 1) for every class along the channel dimension.
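A minimal sketch of the per-pixel case, assuming made-up sizes B, C, H, W and a random multi-hot target in place of real masks: logits and target share the shape [B, C, H, W], and the loss averages over every (image, class, pixel) decision.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4  # hypothetical batch size, class count, height, width

# Raw logits from a segmentation network: one score per class per pixel.
logits = torch.randn(B, C, H, W)

# Multi-hot per-pixel target with the SAME shape as the logits:
# target[b, c, h, w] == 1 means pixel (h, w) of image b belongs to class c,
# and several channels may be 1 at the same pixel.
target = (torch.rand(B, C, H, W) > 0.5).float()

loss = nn.BCEWithLogitsLoss()(logits, target)  # scalar: mean over B*C*H*W terms

# reduction="none" keeps the individual per-class, per-pixel losses.
per_element = nn.BCEWithLogitsLoss(reduction="none")(logits, target)
print(loss.item(), per_element.shape)
```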
My target and output are of shape [B, C, H, W], and the positive weights are of shape [C].
import torch
target = torch.ones([1, 2, 5, 5], dtype=torch.float32)
output = torch.randn(1, 2, 5, 5, requires_grad=True)
positive_weights = torch.FloatTensor([2, 2])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)
loss = criterion(output, target)
loss.backward()
The error is:
RuntimeError: The size of tensor a (2) must match the size of tensor b (5) at non-singleton dimension 3
How can I fix this?
It seems that the problem occurs in pos_weight=positive_weights.
From the docs:
- pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.
But why does it consider 5 as the number of classes?
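The 5 comes from broadcasting: pos_weight is broadcast against the input, and broadcasting aligns shapes starting from the last dimension, so a [C] tensor is matched against W rather than C. A small sketch using torch.broadcast_shapes with the shapes from the question:

```python
import torch

input_shape = (1, 2, 5, 5)  # [B, C, H, W] as in the question

# Broadcasting aligns shapes from the right: a [C] = [2] tensor is compared
# with the last dimension W = 5, not with the channel dimension C = 2.
try:
    torch.broadcast_shapes((2,), input_shape)
    fails_for_flat_weight = False
except RuntimeError as err:
    fails_for_flat_weight = True
    print("shape [2] fails:", err)

# Trailing singleton dimensions move the 2 onto the channel axis.
print(torch.broadcast_shapes((2, 1, 1), input_shape))     # torch.Size([1, 2, 5, 5])
print(torch.broadcast_shapes((1, 2, 1, 1), input_shape))  # torch.Size([1, 2, 5, 5])
```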
As @MariosOreo said, it seems the pos_weight argument throws this error.
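A quick fix might be to permute and reshape the output and target so that the two classes are in dim1 (reshape rather than view, because the permuted tensor is not contiguous and view fails once the batch size is larger than 1):
loss = criterion(output.permute(0, 2, 3, 1).reshape(-1, 2), target.permute(0, 2, 3, 1).reshape(-1, 2))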
or to expand the pos_weight tensor manually:
positive_weights = positive_weights[:, None, None].expand(-1, 5, 5)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)
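The permute-based fix needs a little care: permute returns a non-contiguous view, so .view(-1, 2) only happens to work for a batch size of 1. A minimal sketch, assuming a hypothetical batch size of 4, showing that .view raises while .reshape gives one row of class logits per pixel:

```python
import torch

torch.manual_seed(0)
B, C, H, W = 4, 2, 5, 5  # hypothetical sizes with batch size > 1
output = torch.randn(B, C, H, W)

moved = output.permute(0, 2, 3, 1)  # [B, H, W, C]; shares storage, not contiguous
print(moved.is_contiguous())  # False

# .view cannot merge B, H and W into one dimension with these strides.
try:
    moved.view(-1, C)
    view_ok = True
except RuntimeError:
    view_ok = False

# .reshape copies when necessary and always works.
flat = moved.reshape(-1, C)  # [B*H*W, C]: one row per pixel, one column per class
print(view_ok, flat.shape)  # False torch.Size([100, 2])
```

Row 0 of flat holds the C logits of pixel (0, 0) in image 0, i.e. output[0, :, 0, 0].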
However, this behavior seems unclear to me, so feel free to open a GitHub issue to discuss this use case further.
Confusing error using BCEWithLogitsLoss with weighted loss
Actually, it’s a bug. When pos_weight was added to BCEWithLogitsLoss, it wasn’t meant to be used with per-pixel classifiers, so broadcasting doesn’t work well in this case.
You can work around this bug by reshaping pos_weight, adding two dummy dimensions for H and W:
positive_weights = torch.FloatTensor([2, 2]).reshape(1, 2, 1, 1)
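As a sanity check, here is a minimal sketch comparing the three workarounds from this thread against the weighted BCE formula written out by hand, where pos_weight multiplies only the positive (target == 1) term. The sizes, random data and per-class weights [2.0, 3.0] are hypothetical, chosen only for the comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 3, 2, 5, 5
output = torch.randn(B, C, H, W)                 # logits
target = (torch.rand(B, C, H, W) > 0.5).float()  # multi-hot per-pixel target
weights = torch.tensor([2.0, 3.0])               # hypothetical pos_weight, shape [C]

# Workaround 1: move classes to the last dim and flatten pixels into rows.
loss_flat = nn.BCEWithLogitsLoss(pos_weight=weights)(
    output.permute(0, 2, 3, 1).reshape(-1, C),
    target.permute(0, 2, 3, 1).reshape(-1, C),
)

# Workaround 2: expand pos_weight to [C, H, W].
loss_expanded = nn.BCEWithLogitsLoss(
    pos_weight=weights[:, None, None].expand(-1, H, W)
)(output, target)

# Workaround 3: reshape pos_weight to [1, C, 1, 1] so it broadcasts over B, H, W.
loss_reshaped = nn.BCEWithLogitsLoss(pos_weight=weights.reshape(1, C, 1, 1))(output, target)

# Reference: weighted BCE by hand; the weight scales only the positive term.
p = torch.sigmoid(output)
w = weights.reshape(1, C, 1, 1)
loss_manual = -(w * target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()

print(loss_flat.item(), loss_expanded.item(), loss_reshaped.item(), loss_manual.item())
```

Since broadcasting only needs singleton dimensions, weights[:, None, None] (shape [C, 1, 1]) should work just as well without the .expand call.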
Unfortunately, my implementation of pos_weight has since been rewritten from Python to C++, so I’m not sure I can fix it right away.