Number of output channels for binary segmentation

Hello, I am attempting a binary segmentation problem. I am confused about what output dimensions are expected for loss functions. For BCELoss() should the network and labels both have a shape of nBatchx1xhxw? And for CrossEntropyLoss() should the network output be nBatchx2xhxw and label be nBatchxhxw? It would be great if someone can clarify.
Thanks!

Yes, your explanation is correct and the difference would be:

  • nn.BCEWithLogitsLoss treats your use case directly as binary segmentation, where the labels indicate the negative (0) or positive (1) class. “Soft” labels are also allowed, but it depends on your actual use case how you want to interpret a target of e.g. 0.5. The model output and target are expected to have the same shape of [batch_size, 1, height, width], and the target should contain values in [0, 1]. Use nn.BCEWithLogitsLoss and make sure your model outputs logits, for better numerical stability compared to nn.BCELoss + sigmoid.
  • nn.CrossEntropyLoss can be used for a 2-class multi-class segmentation use case, and the shapes you mentioned are also correct. As you can see, the target tensor does not contain the class dimension, since it now contains class indices in the range [0, nb_classes-1], so 0 and 1 in your case since you are only dealing with 2 classes.
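The two setups above can be sketched with random tensors standing in for the model output and labels (shapes are the point here, not the values):

```python
import torch
import torch.nn as nn

batch_size, h, w = 4, 8, 8

# Binary setup: one output channel of raw logits, target has the SAME
# shape [batch_size, 1, h, w] with float values in [0, 1].
logits_bin = torch.randn(batch_size, 1, h, w)
target_bin = torch.randint(0, 2, (batch_size, 1, h, w)).float()
loss_bin = nn.BCEWithLogitsLoss()(logits_bin, target_bin)

# 2-class multi-class setup: two output channels [batch_size, 2, h, w],
# target has NO class dimension and holds long class indices 0 or 1.
logits_mc = torch.randn(batch_size, 2, h, w)
target_mc = torch.randint(0, 2, (batch_size, h, w))
loss_mc = nn.CrossEntropyLoss()(logits_mc, target_mc)

print(loss_bin.item(), loss_mc.item())
```

Note that passing a float target to nn.CrossEntropyLoss or adding the channel dimension to its target would raise a shape/dtype error, which is a quick way to check you are using the right layout.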

Thanks for the clarification. Can I also use output.view(-1) and target.view(-1) before passing them to nn.BCEWithLogitsLoss? I also have a data imbalance problem (negative:positive ratio is 100:1), for which I am using pos_weight=torch.tensor(100).

It might work if you flatten the output and targets, but I would rather stick to the explicit shape of [batch_size, 1, height, width].
Yes, pos_weight is a proper way to try to reduce overfitting to a majority class.
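A minimal sketch of the pos_weight usage, assuming a 100:1 negative:positive ratio as in your case (the tensor values and shapes are illustrative):

```python
import torch
import torch.nn as nn

# pos_weight is broadcast against the model output, so for a single
# output channel a one-element tensor is enough; 100.0 weights the
# loss of positive targets 100x to counter the 100:1 imbalance.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([100.0]))

logits = torch.randn(2, 1, 4, 4)  # raw logits, no sigmoid applied
target = torch.randint(0, 2, (2, 1, 4, 4)).float()
loss = criterion(logits, target)
print(loss.item())
```

In a real training script the pos_weight tensor should live on the same device as the model output (e.g. create it with device=output.device).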

Ok, when I train with pos_weight=None, the training is smooth, but when I use pos_weight=100 the training becomes very noisy and the loss doesn’t seem to converge fast enough (it does decrease, but very gradually). Could there be any reason for this? Also, am I specifying pos_weight correctly, or should it be a list with length 2 or something?