Semantic segmentation: CrossEntropyLoss with single-channel index mask vs. BCELoss with multi-channel one-hot mask?
Hey all! I'm working with a segmentation dataset that has your typical RGB images as well as index-based masks (the pixel values in the mask represent the classes at those positions), and I'm wondering how best to train a simple U-Net on it. Would it be better to use standard CrossEntropyLoss with the index masks, or to split the masks into one-hot format and use BCELoss on each channel? Or is there a better loss for this type of dataset entirely? I've seen dice loss, but no efficient implementations for multi-class scenarios.
Hello Yuerno -
Conventional wisdom and common practice suggest that you
should use CrossEntropyLoss for a multi-class classification
problem. Reading between the lines, I understand your
segmentation problem to be a straightforward multi-class
classification problem (one class label per pixel), so I would
think that CrossEntropyLoss would be your preferred choice.
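For concreteness, here is a minimal sketch of how CrossEntropyLoss consumes an index mask directly (the shapes and class count below are made up for illustration, not taken from your dataset):

```python
import torch
import torch.nn as nn

# Illustrative shapes: batch of 4, 5 classes, 32x32 masks.
logits = torch.randn(4, 5, 32, 32)           # raw U-Net output, (N, C, H, W)
target = torch.randint(0, 5, (4, 32, 32))    # index mask, (N, H, W), dtype long

# CrossEntropyLoss takes unnormalized logits and integer class
# indices -- no one-hot encoding and no softmax layer needed.
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)             # scalar, averaged over all pixels
```

Note that the target is the raw index mask with one fewer dimension than the logits, which is exactly the format you already have.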
Also, if you propose using BCELoss, then to be concrete you
need to specify how you intend to combine the per-class
BCELoss terms into a single loss function to optimize.
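One concrete choice (though not the only one) is to one-hot encode the mask and let BCEWithLogitsLoss's mean reduction average the per-class, per-pixel terms. A sketch, again with made-up shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 5
logits = torch.randn(4, num_classes, 32, 32)           # (N, C, H, W)
index_mask = torch.randint(0, num_classes, (4, 32, 32))

# One-hot encode to (N, C, H, W) so the target matches the logits.
one_hot = F.one_hot(index_mask, num_classes).permute(0, 3, 1, 2).float()

# BCEWithLogitsLoss applies a per-channel sigmoid internally; its
# default reduction='mean' averages over every class and pixel,
# which is one explicit way of combining the per-class losses.
bce = nn.BCEWithLogitsLoss()
loss = bce(logits, one_hot)
```

Prefer BCEWithLogitsLoss over sigmoid followed by BCELoss for numerical stability; but note this treats each class channel as an independent binary problem rather than as competing classes.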
Lastly, people argue that dice loss works well (better than
CrossEntropyLoss?) when your classes are imbalanced, but
dice loss is mathematically unwieldy (its gradients in
particular), and this is a practical disadvantage.
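Since you mentioned not finding an efficient multi-class implementation: a common way to write a soft dice loss is over softmax probabilities against a one-hot mask, averaged across classes. This is a sketch, not a canonical implementation; the function name and the `eps` smoothing term are my own choices:

```python
import torch
import torch.nn.functional as F

def multiclass_dice_loss(logits, index_mask, eps=1e-6):
    """Soft dice loss. logits: (N, C, H, W); index_mask: (N, H, W) long."""
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(index_mask, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)                              # sum over batch + spatial dims
    intersection = (probs * one_hot).sum(dims)    # per-class overlap
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()                      # average dice over classes
```

Averaging the per-class dice scores (rather than summing intersections across classes first) is what gives the class-imbalance robustness people cite, since each class contributes equally regardless of its pixel count.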
I would recommend that you start with CrossEntropyLoss,
and only if that isn’t working well on your problem (or maybe
if your classes are highly imbalanced) also try dice loss and
compare its performance to that which you achieved with
CrossEntropyLoss. (I don’t think you should use just dice loss
without comparing it to CrossEntropyLoss.)
(I’m not aware of any scheme for combining together multiple
BCELosses that people prefer to CrossEntropyLoss.)
Hi Frank! Thanks so much for this detailed response; it really helped clarify things for me. Your reading between the lines was spot on, so I think I'll stick with regular CrossEntropyLoss for now and see how that goes.