Hi,
I want to implement a dice loss for multi-class segmentation. My solution requires encoding the target tensor with one-hot encoding, because I am working on a multi-label problem. If you have a better solution than this, please feel free to share it.
This loss function needs to be differentiable in order to do backprop. I am not sure how to encode the target while keeping autograd working. I am currently getting this error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
You are right, I don’t even need it to be differentiable: the target is a constant, so no gradient has to flow through the encoding. Here is a new solution; however, I would like to extend the original problem with a new feature: ignore_index.
def dice_loss(output, target, weights=1):
    # output : NxCxHxW float tensor (e.g. softmax probabilities)
    # target : NxHxW long tensor of class indices
    encoded_target = output.data.clone().zero_()
    encoded_target.scatter_(1, target.unsqueeze(1), 1)
    encoded_target = Variable(encoded_target)

    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    num = (output * encoded_target).sum(dim=3).sum(dim=2)
    den1 = output.pow(2).sum(dim=3).sum(dim=2)
    den2 = encoded_target.pow(2).sum(dim=3).sum(dim=2)

    dice = (2 * num / (den1 + den2)) * weights
    return dice.sum() / dice.size(0)
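For reference, here is a minimal usage sketch; the batch size, the 3-class setup and the random target are all illustrative assumptions, and the target is passed as a plain LongTensor of class indices:

import torch
from torch.autograd import Variable
import torch.nn.functional as F

# hypothetical setup: batch of 4, 3 classes, 32x32 images
logits = Variable(torch.randn(4, 3, 32, 32))
output = F.softmax(logits, dim=1)                   # NxCxHxW probabilities
target = torch.LongTensor(4, 32, 32).random_(0, 3)  # NxHxW class indices

score = dice_loss(output, target)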
In semantic segmentation we generally have a label that we want to exclude from the loss; this requirement is already covered by the ignore_index parameter of NLLLoss.
I would like to implement the same for this dice loss. I have already thought of two solutions, but I like neither of them:
- the worst: re-encode all the labels so that ignore_index becomes a valid new label, which implies modifying my classifier layer. This is really ugly for a lot of reasons.
- inside the loss function, remap the ignored label to a new label, expand the output to match the new size, and finally drop this label at the end. I don’t really like this solution either: it involves copying and modifying the targets and expanding the channel dimension of the output tensor (I think).
If you have already faced this kind of problem, I would like to hear your point of view.
Thanks !
Here is my solution for the ignore_index feature; I am not sure it is 100% correct. I added some comments so you can follow the logic behind it. It is simple masking of the tensors.
def dice_loss(output, target, weights=1, ignore_index=None):
    encoded_target = output.data.clone().zero_()
    if ignore_index is not None:
        # mask of the pixels carrying the ignored label
        mask = target == ignore_index
        # clone the target so the original variable is not modified
        filtered_target = target.clone()
        # replace the ignored label with any legal index value
        filtered_target[mask] = 0
        # one-hot encoding
        encoded_target.scatter_(1, filtered_target.unsqueeze(1), 1)
        # expand the mask to the size of the encoded target
        mask = mask.unsqueeze(1).expand(output.data.size())
        # zero out the masked pixels
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
    encoded_target = Variable(encoded_target)

    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    num = (output * encoded_target).sum(dim=3).sum(dim=2)
    den1 = output.pow(2)
    den2 = encoded_target.pow(2)
    if ignore_index is not None:
        # exclude the masked pixels from den1 as well
        den1[mask] = 0

    dice = 2 * (num / (den1 + den2).sum(dim=3).sum(dim=2)) * weights
    return -dice.sum() / dice.size(0)
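A quick sanity check of the masking; the ignore label of 255 and the shapes are hypothetical:

import torch
from torch.autograd import Variable
import torch.nn.functional as F

logits = Variable(torch.randn(2, 3, 8, 8))
output = F.softmax(logits, dim=1)
target = torch.LongTensor(2, 8, 8).random_(0, 3)
target[0, :2, :] = 255  # mark a strip of pixels as "ignore"

loss = dice_loss(output, target, ignore_index=255)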
Hi,
Why are you using loss_per_channel instead of computing the total loss over all channels? Are you computing multiple losses? For example, please explain: will a prediction of [10,3,5,5] with a ground truth of [10,1,5,5] work?
Best
Hi,
If your example means you have 11 labels, using this loss will average 11 dice losses, one for each channel. I chose to keep the loss per channel in case I need to weight the loss of each channel, as in the sketch below. The function still returns the global dice loss, not the loss per channel.
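For instance, with the first version of the function, per-channel weights can be passed as a 1xC tensor that broadcasts against the NxC per-channel dice; the 3-class setup and the weight values are illustrative assumptions:

import torch
from torch.autograd import Variable

# hypothetical weights for 3 classes: down-weight channel 0 (e.g. background)
weights = Variable(torch.Tensor([[0.2, 1.0, 1.0]]))  # 1xC, broadcasts over the NxC dice
loss = dice_loss(output, target, weights=weights)    # output/target as in the earlier sketch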
Hi trypag,
Thanks a lot for the support. I have developed this code for the dice similarity measure following your code.
def dice_loss(self, output, target, weights=None, ignore_index=None):
    # output : NxCxHxW Variable of float tensor
    # target : NxHxW long tensor
    # weights : C float tensor
    # ignore_index : int value to ignore from loss
    smooth = 1.
    loss = 0.

    output = output.exp()  # computes the exponential of each element, e.g. exp(0) = 1
    encoded_target = output.data.clone().zero_()  # zero-filled tensor with the size of output
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        unseq = target.long()  # here
        unseq = unseq.data     # here
        encoded_target.scatter_(1, unseq, 1)
    encoded_target = Variable(encoded_target)

    if weights is None:
        weights = Variable(torch.ones(output.size(1)).type_as(output.data))

    intersection = output * encoded_target
    numerator = 2 * intersection.sum(3).sum(2).sum(0) + smooth
    denominator = (output + encoded_target).sum(3).sum(2).sum(0) + smooth
    loss_per_channel = weights * (1 - (numerator / denominator))  # weights may be directly multiplied

    return loss_per_channel.sum() / output.size(1)
The code seems to be working fine; here are a few things to consider. 1) What is the purpose of taking the exponential of the target variable? 2) I had to change a few lines, shown in bold, as without them I was getting errors; can you have a look at this code? 3) As given in the “Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations” paper, the weights should be multiplied as (1 - weights * (numerator / denominator)), as is also done in the generalised dice loss function here: https://cmiclab.cs.ucl.ac.uk/CMIC/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py
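For reference, the generalised dice formulation in that paper computes per-class weights from the reference volumes and applies them inside both sums, rather than to the per-channel ratio. A rough sketch of that formulation (my reading of the paper, assuming the target is already one-hot encoded; not the code above):

import torch

def generalized_dice_loss(output, target_onehot, smooth=1e-5):
    # output, target_onehot : NxCxHxW softmax probabilities and one-hot target
    # per-class weight: inverse of the squared reference volume (Sudre et al. 2017)
    ref_volume = target_onehot.sum(3).sum(2).sum(0)               # C
    weights = 1.0 / (ref_volume * ref_volume + smooth)            # C
    intersection = (output * target_onehot).sum(3).sum(2).sum(0)  # C
    union = (output + target_onehot).sum(3).sum(2).sum(0)         # C
    return 1 - 2 * (weights * intersection).sum() / ((weights * union).sum() + smooth)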
Best
Output is not the target variable; it’s the output of my model, the feature vectors.
I am using the exponential because the output of my model is log(softmax), so to obtain the softmax I take the exponential of log(softmax). The original formulation is written with the softmax, I just had to adapt it to my model.
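In other words, the two formulations agree up to floating-point precision; a quick check with hypothetical shapes:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.randn(2, 3, 4, 4))
a = F.log_softmax(x, dim=1).exp()  # what the loss receives after exp()
b = F.softmax(x, dim=1)            # the softmax the dice formulation expects
print((a - b).abs().max())         # ~0, up to floating-point error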
I don’t see any code in bold. I noticed you changed encoded_target = output.data.clone().zero_(); the original was encoded_target = output.detach() * 0, which should have worked if output is a Variable.
Hi @trypag, I have a question about the line “numerator = 2 * intersection.sum(0).sum(1).sum(1)”. Here you make the summation over the batch first, which doesn’t seem right to me. The dice loss should be calculated for every example separately and the results then summed together. Am I correct?
Summing over the batch first in the numerator and denominator makes the code compute an approximation of dice, not exactly the dice loss.
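A sketch of the per-example variant, reusing intersection, output, encoded_target and smooth from the code above: sum over the spatial dimensions only, so the dice stays NxC, then average over the batch.

# per-example, per-channel dice: sum over H and W only, keep N and C
numerator = 2 * intersection.sum(3).sum(2) + smooth             # NxC
denominator = (output + encoded_target).sum(3).sum(2) + smooth  # NxC
dice_per_example = numerator / denominator                      # NxC
loss = (1 - dice_per_example).mean()  # average over batch and channels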