I am working on an image segmentation task with images and 7 [2, 5] labels given in .npy format. I checked the class frequencies by summing the ones, since the labels are already one-hot encoded. Following the suggested approach, I am using a weighted (median frequency) Dice + cross-entropy loss, but the performance is not improving.
Are there other approaches that can be used to handle such imbalance and improve performance?
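For reference, the frequency check and weighting described above can be sketched roughly as follows. This is only a minimal sketch: it assumes the labels are stacked into one array of shape (N, H, W, C) with the class axis last, and it uses a simplified median frequency weighting (median class frequency divided by each class frequency, computed over all pixels). The function name `median_frequency_weights` and the toy data are hypothetical.

```python
import numpy as np

def median_frequency_weights(one_hot_labels):
    """Return (pixel counts, weights) per class for one-hot labels of shape (N, H, W, C)."""
    labels = np.asarray(one_hot_labels)
    num_classes = labels.shape[-1]
    # Pixels per class: sum of ones over every axis except the class axis.
    counts = labels.reshape(-1, num_classes).sum(axis=0).astype(np.float64)
    freq = counts / counts.sum()
    # Classes that never occur get weight 0 instead of a division by zero.
    present = freq > 0
    weights = np.zeros(num_classes)
    weights[present] = np.median(freq[present]) / freq[present]
    return counts, weights

# Toy example (assumed data): 10 pixels split 6 / 3 / 1 across 3 classes.
class_ids = np.array([0] * 6 + [1] * 3 + [2] * 1).reshape(1, 2, 5)
toy_labels = np.eye(3)[class_ids]  # one-hot, shape (1, 2, 5, 3)
toy_counts, toy_weights = median_frequency_weights(toy_labels)
```

Rare classes end up with weights above 1 and frequent classes with weights below 1. If the real labels are stored with the class axis first, the array would need to be transposed before calling the function.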
Class imbalance in image segmentation is a notoriously difficult problem. I got decent results with two things:
- using an architecture with a multiscale representation, like HRNetV2
- using the focal loss from this paper: https://arxiv.org/abs/1708.02002. The idea is that the loss contribution of confidently, correctly classified pixels is scaled down, so they barely affect training. This way, most of the parameter updates are concentrated on the pixels and classes that are difficult to classify.
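
To make the second point concrete, here is a rough NumPy sketch of the multi-class form of that loss, FL(p_t) = -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. It only computes the loss value; for training you would write the same formula with your framework's tensor operations. The function name `focal_loss`, the (..., C) layout with the class axis last, and the optional `class_weights` argument are assumptions for illustration. gamma=2 is the value the paper reports working best in its experiments.

```python
import numpy as np

def focal_loss(logits, one_hot_targets, gamma=2.0, class_weights=None, eps=1e-7):
    """Mean focal loss over pixels; logits and one_hot_targets have shape (..., C)."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(one_hot_targets, dtype=np.float64)
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # p_t: predicted probability of the true class at each pixel.
    p_t = np.clip((probs * targets).sum(axis=-1), eps, 1.0)
    # The (1 - p_t)^gamma factor shrinks the loss of confident, correct pixels.
    loss = -((1.0 - p_t) ** gamma) * np.log(p_t)
    if class_weights is not None:
        # Optional per-class weight, selected by the true class of each pixel.
        loss = loss * (targets * np.asarray(class_weights)).sum(axis=-1)
    return loss.mean()

# Two pixels whose true class is 0: one predicted confidently right, one confidently wrong.
demo_logits = np.array([[4.0, 0.0, 0.0], [0.0, 0.0, 4.0]])
demo_targets = np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
demo_loss = focal_loss(demo_logits, demo_targets)
```

With gamma=0 this reduces to plain cross-entropy, so gamma controls how strongly easy pixels are suppressed. Class weights such as median frequency weights can still be passed through `class_weights` if you want to combine both ideas.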
I appreciate your contributions.