Problem with UNet multiclass image segmentation and overlapping classes

Hi,
I’m working on a personal project, and I’m trying to use a UNet to segment the lead and the big boxes on ECG paper as two separate classes. I created a dataset of 200 ECG paper images, each with a lead mask and a big-boxes mask. However, the model always predicts both classes almost the same way: each class shows up in the other class’s prediction. For example:
[Figure_1: example predictions for the lead and boxes classes]
My mask tensors have shape [N, C, H, W] (batch, class, height, width) with C = 2, and the two channel values at each pixel (i.e. mask[n, :, h, w]) are:
[0, 0] if the pixel belongs to neither class
[1, 0] if the pixel belongs to the lead class
[0, 1] if the pixel belongs to the boxes class
[1, 1] if the pixel belongs to both classes
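A minimal NumPy sketch of the multi-hot encoding described above, assuming each class comes from its own mask image for the same ECG picture (`make_multilabel_target` is a hypothetical helper, not part of the original code). Thresholding at 0 turns masks stored as either 0/1 or -1/1 into 0/1 values, and stacking several per-image targets along a new first axis gives the [N, 2, H, W] batch.

```python
import numpy as np


def make_multilabel_target(lead_mask, boxes_mask):
    """Stack two [H, W] masks into a multi-hot [2, H, W] target.

    Channel 0 = lead, channel 1 = boxes. A pixel in both classes
    gets [1, 1]; a pixel in neither class gets [0, 0].
    """
    # "> 0" maps both 0/1 and -1/1 masks to 0/1
    lead = (np.asarray(lead_mask) > 0).astype(np.float32)
    boxes = (np.asarray(boxes_mask) > 0).astype(np.float32)
    return np.stack([lead, boxes], axis=0)


# Hypothetical 2x2 masks
lead = np.array([[1, 0], [1, 0]])
boxes = np.array([[1, 1], [0, 0]])
target = make_multilabel_target(lead, boxes)
print(target.shape)             # (2, 2, 2)
print(target[:, 0, 0].tolist())  # [1.0, 1.0] -> both classes
print(target[:, 1, 1].tolist())  # [0.0, 0.0] -> background
```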
I’m using Dice loss with the Adam optimizer. The model outputs a tensor of shape [N, 2, H, W] (one channel per class); I squeeze it and split it with NumPy indexing, taking [0, :, :] as the lead prediction and [1, :, :] as the boxes prediction.
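Indexing the channel axis this way is, on its own, a valid way to separate the two classes, so nearly identical masks more likely point at the model, loss, or targets. The sketch below assumes the network returns raw logits for a single image ([1, 2, H, W]) and applies an independent sigmoid plus a 0.5 threshold per channel; the helper name and threshold are assumptions, not the original code.

```python
import numpy as np


def split_multilabel_logits(logits, threshold=0.5):
    """Turn [1, 2, H, W] logits into two independent boolean masks.

    Each channel gets its own sigmoid, so a pixel can be on in both
    masks (multi-label); a softmax would make the channels compete.
    """
    probs = 1.0 / (1.0 + np.exp(-logits))  # element-wise sigmoid
    probs = np.squeeze(probs, axis=0)      # drop only the batch axis -> [2, H, W]
    masks = probs > threshold
    lead_mask = masks[0]   # channel 0: lead
    boxes_mask = masks[1]  # channel 1: boxes
    return lead_mask, boxes_mask


# Hypothetical logits for one 1x2 image
logits = np.array([[[[3.0, -3.0]],    # lead channel
                    [[-3.0, 3.0]]]])  # boxes channel
lead_mask, boxes_mask = split_multilabel_logits(logits)
print(lead_mask.tolist())   # [[True, False]]
print(boxes_mask.tolist())  # [[False, True]]
```

Passing `axis=0` to `np.squeeze` keeps a size-1 height or width intact, which a bare `squeeze()` would silently drop.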
I don’t understand why each class’s prediction also contains the other class. Could the problem be the way I split the output into two arrays?
(There are some minor differences between the two predictions, so I know they are not exact copies of each other.)
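One way to put a number on "each class shows up in the other class’s prediction" is to compare the IoU between the two predicted channels with the IoU between the two target channels: if the predictions overlap far more than the targets do, the channels are collapsing onto each other. `channel_iou` and the masks below are hypothetical, for illustration only.

```python
import numpy as np


def channel_iou(mask_a, mask_b):
    """IoU between two masks: 1.0 means identical, 0.0 means disjoint."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 1.0  # both empty: treat as identical
    return float(np.logical_and(a, b).sum() / union)


# Hypothetical predicted and ground-truth masks
pred_lead = np.array([[1, 1], [1, 0]])
pred_boxes = np.array([[1, 1], [1, 1]])
true_lead = np.array([[1, 0], [0, 0]])
true_boxes = np.array([[0, 1], [1, 1]])

print(channel_iou(pred_lead, pred_boxes))  # 0.75
print(channel_iou(true_lead, true_boxes))  # 0.0
```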

Could you check how your criterion processes the model outputs and whether it supports multi-label targets? For debugging, you could also pass the model output to nn.BCEWithLogitsLoss and compare the results with your current approach.
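When comparing criteria, a per-channel loss shows whether both classes are actually being learned. The NumPy sketch below reimplements the numerically stable element-wise formula that `nn.BCEWithLogitsLoss` averages, max(x, 0) - x·y + log(1 + exp(-|x|)), and keeps the channel axis; in PyTorch the rough equivalent would be `nn.BCEWithLogitsLoss(reduction="none")(logits, targets).mean(dim=(0, 2, 3))`. The helper name is hypothetical, and targets are assumed to be in [0, 1].

```python
import numpy as np


def bce_with_logits_per_channel(logits, targets):
    """Mean binary cross-entropy per channel for [N, C, H, W] arrays.

    Uses the stable form max(x, 0) - x * y + log(1 + exp(-|x|)).
    Targets are expected to lie in [0, 1].
    """
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    elementwise = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    # Average over batch and spatial dims, keep one value per channel
    return elementwise.mean(axis=(0, 2, 3))


# Hypothetical batch: all-zero logits give log(2) for every pixel
logits = np.zeros((1, 2, 4, 4))
targets = np.zeros((1, 2, 4, 4))
targets[:, 0, :2, :] = 1.0  # top half belongs to the lead channel
print(bce_with_logits_per_channel(logits, targets))  # both entries ~ log(2) ~ 0.6931
```

Logging these two numbers every epoch makes it easy to see if one channel’s loss stalls while the other keeps falling.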

Hi,
The Dice loss is the only part I didn’t write myself; I found this implementation online:

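```python
import torch
import torch.nn as nn


class MultilabelDiceLoss(nn.Module):
    """Soft Dice loss averaged over batch items and channels (classes)."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth  # avoids division by zero for empty masks

    def forward(self, inputs, targets):
        # inputs: raw logits [N, C, H, W]; tanh squashes them into (-1, 1)
        inputs = torch.tanh(inputs)
        targets = targets.float()

        # Flatten the spatial dims: [N, C, H*W]
        inputs = inputs.reshape(inputs.size(0), inputs.size(1), -1)
        targets = targets.reshape(targets.size(0), targets.size(1), -1)

        # Map (-1, 1) -> (0, 1). For the targets this assumes they are
        # stored in [-1, 1]; targets already in {0, 1} become {0.5, 1}.
        inputs = (inputs + 1) / 2
        targets = (targets + 1) / 2

        # Per-sample, per-channel overlap and total mass of both maps
        intersection = (inputs * targets).sum(2)
        union = inputs.sum(2) + targets.sum(2)

        dice = (2 * intersection + self.smooth) / (union + self.smooth)

        loss = 1 - dice
        return loss.mean()
```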
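Two properties of this loss are worth checking. On the input side, tanh followed by (x + 1) / 2 is exactly sigmoid(2x), so the predictions become a valid per-channel probability map. On the target side, the same remap is only correct when masks are stored as -1/1: masks stored as 0/1 turn every background pixel into 0.5, so the loss treats all pixels as half foreground in both channels. A small NumPy check with hypothetical values:

```python
import numpy as np

x = np.linspace(-5, 5, 11)              # hypothetical logits
tanh_mapped = (np.tanh(x) + 1) / 2      # what the loss does to the inputs
sigmoid_2x = 1 / (1 + np.exp(-2 * x))   # sigmoid with doubled slope
print(np.allclose(tanh_mapped, sigmoid_2x))  # True

targets_01 = np.array([0.0, 1.0])       # background, foreground stored as 0/1
print(((targets_01 + 1) / 2).tolist())  # [0.5, 1.0] -> background no longer 0
targets_pm1 = np.array([-1.0, 1.0])     # background, foreground stored as -1/1
print(((targets_pm1 + 1) / 2).tolist())  # [0.0, 1.0]
```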

As for nn.BCEWithLogitsLoss, it gives me negative loss values from the first epoch.
One thing I forgot to mention: the pixel values are normalized to [-1, 1], which might be the reason for the negative loss values with nn.BCEWithLogitsLoss.
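Binary cross-entropy only depends on the logits and the targets, and it is guaranteed to be non-negative only when the targets lie in [0, 1]. A target of -1 lets the element-wise loss drop below zero, and it becomes unbounded below as the logit goes to minus infinity, so negative values point at mask values outside [0, 1] rather than at the input image range. A NumPy sketch of the same stable formula, with hypothetical numbers:

```python
import numpy as np


def bce_with_logits(x, y):
    # Element-wise stable BCE-with-logits: max(x, 0) - x*y + log(1 + exp(-|x|))
    return np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))


logit = -2.0   # hypothetical confident "background" prediction
target = -1.0  # background pixel stored as -1
print(bce_with_logits(logit, target))            # negative, about -1.87
print(bce_with_logits(logit, (target + 1) / 2))  # about 0.127 once remapped to 0
```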

I changed the pixel values to the range [0, 1], which fixed the negative loss values. Then I trained with nn.BCEWithLogitsLoss; after 60 epochs, the results were the same as with the Dice loss.