Loss function for multi-class semantic segmentation

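Actually, it’s a bug. When pos_weight was added to BCEWithLogitsLoss, it wasn’t intended to be used with per-pixel classifiers, so broadcasting doesn’t work as expected here. A pos_weight of shape (C,) is broadcast against the last dimension of an (N, C, H, W) input, which is W rather than the class dimension C. This raises an error when W differs from C and silently weights the wrong axis when they happen to be equal.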
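A small sketch that reproduces the mismatch. The shapes and weight values here are arbitrary, chosen only for illustration (W is deliberately different from C). Passing a per-class weight of shape (C,) to BCEWithLogitsLoss with an (N, C, H, W) input fails to broadcast because the weight is aligned with W.

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 3, 4, 5  # arbitrary example shapes; W != C on purpose
logits = torch.randn(N, C, H, W)
target = torch.randint(0, 2, (N, C, H, W)).float()

# One positive weight per class, shape (C,) -- this is the problematic layout
pos_weight = torch.tensor([1.0, 2.0, 3.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

try:
    criterion(logits, target)
    mismatch_raised = False
except RuntimeError as err:
    # (C,) is broadcast against the trailing W dimension, not against C,
    # so sizes 3 and 5 cannot be matched
    mismatch_raised = True
    print("broadcast error:", err)
```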

You can work around this bug by reshaping pos_weight with two dummy dimensions for H and W, so that it has shape (1, C, 1, 1) and lines up with the class dimension:

positive_weights = torch.tensor([2.0, 2.0]).reshape(1, 2, 1, 1)  # shape (1, C, 1, 1) with C = 2
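A minimal end-to-end sketch of the workaround, assuming a 2-class model whose logits and multi-hot target both have shape (N, C, H, W). The shapes and the weights [1.0, 3.0] are arbitrary choices (distinct so the per-class effect is visible). The result is checked against the per-element formula from the BCEWithLogitsLoss documentation, -[p_c * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], averaged over all elements.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 2, 4, 5  # two classes, as in the post
logits = torch.randn(N, C, H, W)
target = torch.randint(0, 2, (N, C, H, W)).float()

# Per-class positive weights reshaped to (1, C, 1, 1) so they broadcast
# over the class dimension and are repeated across every pixel (H, W).
positive_weights = torch.tensor([1.0, 3.0]).reshape(1, C, 1, 1)
criterion = nn.BCEWithLogitsLoss(pos_weight=positive_weights)
loss = criterion(logits, target)

# Reference computation of the same loss, written out by hand:
# log(sigmoid(x)) = logsigmoid(x), log(1 - sigmoid(x)) = logsigmoid(-x)
manual = -(positive_weights * target * F.logsigmoid(logits)
           + (1 - target) * F.logsigmoid(-logits)).mean()

print(loss.item(), manual.item())
```

A weight of shape (C, 1, 1) would broadcast the same way, since broadcasting prepends the missing leading dimension.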

Unfortunately, my pos_weight implementation has since been rewritten from Python to C++, so I’m not sure I can fix it right away.
