(Help Appreciated) UNet Binary Image Segmentation Issue

Hello, I’m training a binary image segmentation model using the smp (segmentation_models_pytorch) library in PyTorch. The labels are imbalanced: on average, there are about 20 times more background pixels than foreground (class) pixels. To tackle this, I’ve tried the Dice loss, the IoU (Jaccard) loss, and the focal loss (with alpha), as well as various weighted combinations of the three.

No matter which loss function I try, the same behavior happens: both training and validation loss go down as the epochs progress. I made sure the class distribution is equal across the 5 folds (I’m doing 5-fold CV), so the train/validation split shouldn’t be the issue. I’ve also done enough image augmentations (color jitter, noise, etc.) and have visually confirmed that they are applied consistently to both the image and the mask. However, two odd behaviors persist no matter how much I play with the loss functions:
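For reference, a minimal sketch of what I mean by applying augmentations consistently, in plain PyTorch (the function below is illustrative, not my actual pipeline — the key point is sampling each random decision once and applying it to both tensors, while photometric jitter touches the image only):

```python
import torch

def paired_augment(image: torch.Tensor, mask: torch.Tensor):
    """Apply the same spatial transform to image and mask.

    `image` is (C, H, W) float, `mask` is (1, H, W) binary; the names
    are illustrative, not from my actual setup.
    """
    # Sample the random decision ONCE so both tensors get the same flip.
    if torch.rand(1).item() < 0.5:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    # Photometric jitter (brightness here) touches the image only --
    # the mask must stay binary.
    image = image * (0.8 + 0.4 * torch.rand(1).item())
    return image, mask
```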

  1. After a certain epoch (epoch ≈ 100), the model’s performance on a held-out test set begins to deteriorate, even though the training and validation loss continues to decrease.
  2. Validation loss at epoch 1 starts out much higher (0.4) than training loss (0.1), even though both do end up decreasing as the epochs continue.

My guess is that the model eventually learns to “cheat” and just predicts a blank (all-background) mask every time, since background is 20x more common, and gets a lower loss that way. However, I’m frustrated because I think my loss function should be able to handle this. For some more information, I’m using the UnetPlusPlus model from smp with a resnet50 encoder. Below is my code for one of the loss function combinations I used. These loss functions can be found in this link.

dice_loss_func = smp.losses.DiceLoss(mode='binary', from_logits=True)
iou_loss_func = smp.losses.JaccardLoss(mode='binary', from_logits=True)
focal_loss_func = smp.losses.FocalLoss(mode='binary', alpha=0.95, gamma=2)

def loss_func(y_pred, y_true):
    # Weighted combination of the three losses; the weights could be tuned.
    return (0.2 * focal_loss_func(y_pred, y_true)
            + 0.5 * dice_loss_func(y_pred, y_true)
            + 0.2 * iou_loss_func(y_pred, y_true))
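One diagnostic that could confirm or rule out the blank-prediction hypothesis (a hypothetical helper, not part of my training code): log the fraction of pixels predicted as foreground each epoch. If it collapses toward zero while the loss keeps falling, the model is indeed “cheating”:

```python
import torch

@torch.no_grad()
def foreground_fraction(logits: torch.Tensor, threshold: float = 0.5) -> float:
    """Fraction of pixels predicted as foreground.

    `logits` is the raw model output, shape (N, 1, H, W). Values near 0.0
    over a whole epoch mean the model is predicting all-background masks.
    """
    probs = torch.sigmoid(logits)
    return (probs > threshold).float().mean().item()
```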

Any help would be appreciated, and I can provide more details if needed. Thank you!

@chokevin8 It’s possible that your issue isn’t related to class imbalance but to something else in your setup. Could you try training the same model on a balanced dataset and see if you get the same behaviour? If you don’t, you can be more certain that the class imbalance is the issue.

Additionally, if it is in fact a class imbalance issue, then weighted cross entropy loss (nn.CrossEntropyLoss with the weight argument) should be super useful.
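A sketch of what I mean (the [1.0, 20.0] weights just mirror the ~20:1 ratio you mentioned and would need tuning; note this formulation expects 2-channel logits and a long-typed class-index mask, unlike your single-channel binary setup):

```python
import torch
import torch.nn as nn

# Class weights reflecting the ~20:1 background:foreground ratio
# from the post; the exact values would need tuning.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 20.0]))

# CrossEntropyLoss expects 2-channel logits (N, 2, H, W) and a
# long-typed class-index mask (N, H, W).
logits = torch.randn(4, 2, 64, 64)            # dummy predictions
target = torch.randint(0, 2, (4, 64, 64))     # dummy binary mask
loss = criterion(logits, target)
```

If you want to keep the single-channel setup, nn.BCEWithLogitsLoss with its pos_weight argument is the equivalent.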

If it’s not a class imbalance problem, try a simple UNet model and see if that makes the problem go away.
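Something like this toy UNet in plain PyTorch would do for the sanity check (a sketch I’m making up for illustration, not the smp model):

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
    )

class TinyUNet(nn.Module):
    """Minimal one-skip UNet for sanity checks only."""

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        self.enc1 = conv_block(in_ch, 16)
        self.down = nn.MaxPool2d(2)
        self.enc2 = conv_block(16, 32)
        self.up = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.dec1 = conv_block(32, 16)   # 32 in: upsampled 16 + skip 16
        self.head = nn.Conv2d(16, out_ch, 1)

    def forward(self, x):
        s1 = self.enc1(x)                       # skip connection
        x = self.enc2(self.down(s1))
        x = self.up(x)
        x = self.dec1(torch.cat([x, s1], dim=1))
        return self.head(x)                     # raw logits, (N, 1, H, W)
```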

The quality of the data annotations will also affect training. Additionally, this may be a case where you want to use an optimizer with weight decay.
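For example, AdamW, which decouples weight decay from the gradient update (the 1e-2 value below is just a common starting point, not something tuned to your data):

```python
import torch

# Stand-in module for illustration; you would pass your actual
# segmentation model's parameters instead.
model = torch.nn.Conv2d(3, 1, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
```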