Multiclass Image Segmentation

I am working on multi-class image segmentation and am currently facing challenges with my dataset. The labels (ground truth/targets) are already one-hot encoded for the two class labels, but the background is not given. Firstly, is annotating or labeling the background necessary for the performance of the model, since it will be dropped during prediction or inference?

Secondly, due to the highly imbalanced nature of the dataset, the approaches suggested on the forum are either a weighted sampler or a weighted loss function. Since a weighted dice loss is my preferred choice, how do I compute the weight for each class (the suggested approaches target classification tasks, not segmentation)?

Finally, what is the significance of a combined loss function, i.e. dice loss plus cross-entropy loss?

I appreciate your explanations and suggestions in advance.

The background class would be necessary if your model is supposed to predict the class of each pixel.
During training you could ignore the background class pixels, but then your model won’t be able to predict the background during evaluation and the predicted segmentation would contain only the two valid classes.

I’m unsure if weighted sampling would work nicely for segmentation use cases, since each sample will most likely contain more than a single class. Drawing a particular sample would thus increase the class count of multiple classes, and you would run into the same limitations as in a multi-label classification use case.

Both losses will be used to calculate the gradients, and each parameter will accumulate the gradients from both.
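
For illustration, here is a minimal sketch of such a combined loss, assuming raw logits of shape [N, C, H, W] and integer class targets of shape [N, H, W]; the function names, the 3-class setup, and the alpha weighting are placeholders, not a prescribed implementation:

```python
import torch
import torch.nn.functional as F

def dice_loss(logits, targets, eps=1e-6):
    # logits: [N, C, H, W] raw model outputs; targets: [N, H, W] class indices
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    # one-hot encode the targets to [N, C, H, W]
    targets_1h = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and spatial dims, keep per-class values
    intersection = (probs * targets_1h).sum(dims)
    cardinality = probs.sum(dims) + targets_1h.sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()

def combined_loss(logits, targets, alpha=0.5):
    # both terms contribute to the gradients of every parameter
    ce = F.cross_entropy(logits, targets)
    dl = dice_loss(logits, targets)
    return alpha * ce + (1.0 - alpha) * dl

# usage: 2 foreground classes + background -> 3 output channels
logits = torch.randn(2, 3, 64, 64, requires_grad=True)
targets = torch.randint(0, 3, (2, 64, 64))
loss = combined_loss(logits, targets)
loss.backward()
```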

Thanks @ptrblck for the explanation.

Regarding reporting the model performance during both the training and validation phases, the background class is not needed. However, I think it might help when trying to visualize the predictions of the model.

Moreover, to compute the background, I stacked the annotated (masked) labels of the given classes and took the complement, i.e. setting 1’s where the stacked label has 0’s and vice versa, with little attention paid to overlapping pixel regions.
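
As a sketch of that complement trick (the tensor names and dummy masks are hypothetical; the clamp is one way to handle the overlapping regions mentioned above):

```python
import torch

H, W = 64, 64
# binary masks for the two annotated classes (dummy data here)
mask_class0 = torch.randint(0, 2, (H, W))
mask_class1 = torch.randint(0, 2, (H, W))

class_masks = torch.stack([mask_class0, mask_class1], dim=0)  # [2, H, W]

# background = complement of the union of the class masks;
# clamp guards against overlapping pixels summing above 1
foreground = class_masks.sum(dim=0).clamp(max=1)
background = 1 - foreground

# one-hot target including background as its own channel: [3, H, W]
target = torch.cat([background.unsqueeze(0), class_masks], dim=0)
```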

I appreciate your clear explanation. In this case, how should one handle class imbalance in segmentation use cases, in particular how to obtain the class weights for each class, rather than using a robust network like HighResNet as someone suggested in one of the forum threads?

You could use a weighted loss, e.g. by using an unreduced loss (reduction='none'), multiplying each loss “pixel” by its class weight, reducing the result, and finally calling backward().
I’m sure there are other valid approaches. For example, is focal loss only used in detection models, or could it also be used for segmentation use cases?
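
A minimal sketch of the unreduced-loss approach described above; the inverse pixel-frequency heuristic for the class weights is an assumption, not something prescribed here:

```python
import torch
import torch.nn.functional as F

num_classes = 3  # 2 foreground classes + background
logits = torch.randn(2, num_classes, 64, 64, requires_grad=True)
targets = torch.randint(0, num_classes, (2, 64, 64))

# class weights from inverse pixel frequency (one possible heuristic);
# clamp avoids division by zero for classes absent from the batch
counts = torch.bincount(targets.flatten(), minlength=num_classes).float()
weights = counts.sum() / (num_classes * counts.clamp(min=1))

# unreduced loss: one value per pixel, shape [2, 64, 64]
loss_per_pixel = F.cross_entropy(logits, targets, reduction='none')

# multiply each pixel's loss by the weight of its ground-truth class,
# then reduce and backpropagate
pixel_weights = weights[targets]
loss = (loss_per_pixel * pixel_weights).mean()
loss.backward()
```

Note that F.cross_entropy also accepts a weight argument that applies the same per-class weighting internally; the manual version above just makes the mechanism explicit and generalizes to arbitrary per-pixel weight maps.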

Thanks for the explanation.