Multi-label Dice loss

class DICELossMultiClass(nn.Module):

def __init__(self):
    super(DICELossMultiClass, self).__init__()

def forward(self, output, mask):
    num_classes = output.size(-1)
    dice_eso = 0
    for i in range(num_classes):
        probs = torch.squeeze(output[:, :, :, :, i], -1)    # [batch_size, D, H, W, channel]
        target = torch.squeeze(mask[:, :, :, :, i], -1)     # one-hot :[batch_size, D, H, W, channel]

        num = probs * target
        num = torch.sum(num, 3)
        num = torch.sum(num, 2)
        num = torch.sum(num, 1)

        # print( num )

        den1 = probs * probs
        # print(den1.size())
        den1 = torch.sum(den1, 3)
        den1 = torch.sum(den1, 2)
        den1 = torch.sum(den1, 1)

        # print(den1.size())

        den2 = target * target
        # print(den2.size())
        den2 = torch.sum(den2, 3)
        den2 = torch.sum(den2, 2)
        den2 = torch.sum(den2, 1)

        # print(den2.size())
        eps = 0.0000001
        dice = 2 * ((num + eps) / (den1 + den2 + eps))
        # dice_eso = dice[:, 1:]
        dice_eso += dice

    loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
    return loss

I used the above function to computed the multi-label dice loss(the target label is encoded by one-hot), but I don’t know whether it is correct because I got some negative value for loss. I think the main reason is that one of channels is for background, which is set to 1 for background pixels,this is different from the binary semantic segmentation.
What can I do for this ?

1 Like