Loss function for multi-class semantic segmentation

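As @MariosOreo said, the pos_weight argument seems to be causing this error. It is broadcast against the trailing dimensions of the output, so a weight tensor with one entry per class does not line up with the class dimension (dim1) of an [N, 2, 5, 5] output.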
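A quick fix might be to permute and reshape the output and target so that the two classes end up in dim1 of a 2D [N*H*W, 2] tensor. Use reshape (or .contiguous().view()) here rather than view, because the permuted tensor is not contiguous and view raises an error for batch sizes larger than 1:

loss = criterion(output.permute(0, 2, 3, 1).reshape(-1, 2), target.permute(0, 2, 3, 1).reshape(-1, 2))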
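A minimal, self-contained sketch of the permute approach follows. It assumes a [4, 2, 5, 5] logit map, random binary targets, and hypothetical per-class weights [1.0, 3.0]; none of these values come from the original thread. It also shows why reshape is preferable to view here.

```python
import torch

torch.manual_seed(0)

# Assumed shapes for illustration: batch of 4, 2 classes, 5x5 spatial map.
N, C, H, W = 4, 2, 5, 5
output = torch.randn(N, C, H, W)                    # raw logits
target = torch.randint(0, 2, (N, C, H, W)).float()  # binary mask per class channel
positive_weights = torch.tensor([1.0, 3.0])         # hypothetical per-class weights

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)

# view() needs compatible strides; after permute the tensor is not contiguous,
# so flattening the batch and spatial dims with view() fails when N > 1.
try:
    output.permute(0, 2, 3, 1).view(-1, C)
    view_failed = False
except RuntimeError:
    view_failed = True

# Move the class channel to the last dim, then flatten everything else.
# reshape() copies the data when a plain view is not possible.
flat_output = output.permute(0, 2, 3, 1).reshape(-1, C)  # [N*H*W, C]
flat_target = target.permute(0, 2, 3, 1).reshape(-1, C)  # [N*H*W, C]

# pos_weight of shape [C] now broadcasts along the last (class) dimension.
loss = criterion(flat_output, flat_target)
print(view_failed, tuple(flat_output.shape), loss.item())
```

The loss is the mean over all N*H*W*2 elements, since BCEWithLogitsLoss uses reduction='mean' by default.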

Alternatively, expand the pos_weight tensor manually to [2, 5, 5] so that it broadcasts against the [N, 2, 5, 5] output:

positive_weights = positive_weights[:, None, None].expand(-1, 5, 5)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)
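Here is a sketch of the expand approach, using the same assumed shapes and hypothetical weights as above. The snippet also runs the permute + reshape variant on the same data. Both variants average over the same set of N*2*5*5 elements, so the two losses should agree up to floating-point error.

```python
import torch

torch.manual_seed(0)

N, C, H, W = 4, 2, 5, 5                              # assumed shapes (5x5 spatial map)
output = torch.randn(N, C, H, W)                     # raw logits
target = torch.randint(0, 2, (N, C, H, W)).float()   # binary mask per class channel
positive_weights = torch.tensor([1.0, 3.0])          # hypothetical per-class weights

# Expand pos_weight from [C] to [C, H, W]; it then broadcasts against [N, C, H, W].
expanded = positive_weights[:, None, None].expand(-1, H, W)
criterion_expanded = torch.nn.BCEWithLogitsLoss(pos_weight=expanded)
loss_expanded = criterion_expanded(output, target)

# Same data through the permute + reshape variant, for comparison.
criterion_flat = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)
loss_flat = criterion_flat(
    output.permute(0, 2, 3, 1).reshape(-1, C),
    target.permute(0, 2, 3, 1).reshape(-1, C),
)

print(loss_expanded.item(), loss_flat.item())
```

The expanded weight is tied to a fixed spatial size, so it would need to be rebuilt if the input resolution changes. The flattened variant does not have that restriction.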

However, this behavior seems unclear to me, so feel free to open a GitHub issue to discuss this use case further.
