Hello all, I am using dice loss for multiple class (4 classes problem). I want to use weight for each class at each pixel level. So, my weight will have size of BxCxHxW
(C=4) in my case. How can I use the weight to assign to dice loss? This is my current solution that multiple the weight with the input (network prediction) after softmax
class SoftDiceLoss(nn.Module):
def __init__(self, n_classes):
super(SoftDiceLoss, self).__init__()
self.one_hot_encoder = One_Hot(n_classes).forward
self.n_classes = n_classes
def forward(self, input, target, weight):
smooth = 0.01
batch_size = input.size(0)
input = F.softmax(input, dim=1)
input = input*weight
input = input.view(batch_size, self.n_classes, -1)
target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_classes, -1)
inter = torch.sum(input * target, 2) + smooth
union = torch.sum(input, 2) + torch.sum(target, 2) + smooth
score = torch.sum(2.0 * inter / union)
score = 1.0 - score / (float(batch_size) * float(self.n_classes))
return score
And the second solution is that multiply the weight in the inter and union position
class SoftDiceLoss(nn.Module):
def __init__(self, n_classes):
super(SoftDiceLoss, self).__init__()
self.one_hot_encoder = One_Hot(n_classes).forward
self.n_classes = n_classes
def forward(self, input, target, weight):
smooth = 0.01
batch_size = input.size(0)
input = F.softmax(input, dim=1).view(batch_size, self.n_classes, -1)
target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_classes, -1)
weight = weight.view(batch_size, self.n_classes, -1)
inter = torch.sum(input * target * weight, 2) + smooth
union = torch.sum(input*weight, 2) + torch.sum(target*weight, 2) + smooth
score = torch.sum(2.0 * inter / union)
score = 1.0 - score / (float(batch_size) * float(self.n_classes))
return score
Which one is correct?