Weighted pixel-wise Dice loss for multiple classes

Hello all, I am using Dice loss for a multi-class problem (4 classes). I want to apply a weight for each class at each pixel, so my weight tensor has size B×C×H×W (C = 4 in my case). How can I apply this weight in the Dice loss? My current solution multiplies the weight with the input (the network prediction) after softmax:

class SoftDiceLoss(nn.Module):
    """Soft multi-class Dice loss with a per-pixel weight applied to the prediction.

    The weight multiplies the softmax probabilities before the Dice terms are
    formed; the one-hot target is left unweighted.

    NOTE(review): because only the prediction is weighted, the union mixes a
    weighted prediction sum with an unweighted target sum. With weights > 1 the
    per-class score can exceed 1 (and the loss become negative), so this is not
    a true weighted Dice. Weighting the intersection and BOTH union sums keeps
    the score in [0, 1].
    """

    def __init__(self, n_classes, smooth=0.01):
        """
        Args:
            n_classes: number of classes C; must match dim 1 of the logits.
            smooth: additive smoothing applied to both the numerator and the
                denominator of the Dice ratio, so a class absent from both
                prediction and target scores 1 instead of 0/0.
        """
        super().__init__()
        # One_Hot is defined elsewhere; presumably it maps integer labels to a
        # one-hot tensor with the class axis at dim 1 — TODO confirm.
        self.one_hot_encoder = One_Hot(n_classes).forward
        self.n_classes = n_classes
        self.smooth = smooth

    def forward(self, input, target, weight):
        """Return ``1 - mean Dice`` over the batch and classes.

        Args:
            input: raw logits with the class axis at dim 1 (softmax over dim=1).
            target: integer labels, converted by ``self.one_hot_encoder``.
            weight: tensor broadcastable against the softmax output (the post
                uses B x C x H x W).

        Returns:
            0-dim tensor: the loss averaged over batch and classes.
        """
        batch_size = input.size(0)

        probs = F.softmax(input, dim=1) * weight
        probs = probs.reshape(batch_size, self.n_classes, -1)
        target = self.one_hot_encoder(target).reshape(batch_size, self.n_classes, -1)

        inter = torch.sum(probs * target, dim=2)
        union = torch.sum(probs, dim=2) + torch.sum(target, dim=2)

        # Smoothing enters numerator and denominator equally, so a perfect
        # prediction (or a class empty in both tensors) scores exactly 1.
        dice = (2.0 * inter + self.smooth) / (union + self.smooth)

        # mean over (B, C) == sum / (batch_size * n_classes)
        return 1.0 - dice.mean()

My second solution multiplies the weight into both the intersection and the union terms:

class SoftDiceLoss(nn.Module):
    """Soft multi-class Dice loss with per-pixel weights on every Dice term.

    The weight multiplies the intersection and both union sums, so each pixel
    contributes to the numerator and the denominator in proportion to its
    weight, and the per-class score stays in [0, 1] for non-negative weights.
    """

    def __init__(self, n_classes, smooth=0.01):
        """
        Args:
            n_classes: number of classes C; must match dim 1 of the logits.
            smooth: additive smoothing applied to both the numerator and the
                denominator of the Dice ratio, so a class absent from both
                prediction and target scores 1 instead of 0/0.
        """
        super().__init__()
        # One_Hot is defined elsewhere; presumably it maps integer labels to a
        # one-hot tensor with the class axis at dim 1 — TODO confirm.
        self.one_hot_encoder = One_Hot(n_classes).forward
        self.n_classes = n_classes
        self.smooth = smooth

    def forward(self, input, target, weight):
        """Return ``1 - mean weighted Dice`` over the batch and classes.

        Args:
            input: raw logits with the class axis at dim 1 (softmax over dim=1).
            target: integer labels, converted by ``self.one_hot_encoder``.
            weight: either B*C*... elements laid out like the softmax output
                (e.g. B x C x H x W), or a tensor broadcastable to it
                (e.g. a per-pixel B x 1 x H x W weight shared by all classes).

        Returns:
            0-dim tensor: the loss averaged over batch and classes.
        """
        batch_size = input.size(0)

        probs = F.softmax(input, dim=1)
        if weight.numel() != probs.numel():
            # A plain view() would silently scramble e.g. a (B, 1, H, W)
            # weight; broadcast it across the class axis first.
            weight = weight.expand_as(probs)

        probs = probs.reshape(batch_size, self.n_classes, -1)
        target = self.one_hot_encoder(target).reshape(batch_size, self.n_classes, -1)
        # reshape (not view) also accepts non-contiguous / expanded weights.
        weight = weight.reshape(batch_size, self.n_classes, -1)

        inter = torch.sum(probs * target * weight, dim=2)
        union = torch.sum(probs * weight, dim=2) + torch.sum(target * weight, dim=2)

        # Smoothing enters numerator and denominator equally, so a perfect
        # prediction (or a class empty in both tensors) scores exactly 1.
        dice = (2.0 * inter + self.smooth) / (union + self.smooth)

        # mean over (B, C) == sum / (batch_size * n_classes)
        return 1.0 - dice.mean()

Which one is correct?

1 Like

Can you share your `One_Hot(n_classes).forward` implementation?