One-hot encoding with autograd (Dice loss)

Hi,
I want to implement a Dice loss for multi-class segmentation. My solution requires one-hot encoding the target tensor because I am working on a multi-class problem. If you have a better solution than this, please feel free to share it.
This loss function needs to be differentiable in order to do backprop. I am not sure how to encode the target while keeping autograd working. I am currently getting this error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Code based on @rogetrullo's work: https://github.com/pytorch/pytorch/issues/1249

def dice_loss(output, target):
    """
    output is a torch Variable of size Batch x nclasses x H x W representing log probabilities for each class
    target holds the ground-truth class indices (Batch x H x W); it is one-hot encoded below to the same size as output
    """
    encoded_target = Variable(output.data.clone())
    encoded_target[...] = 0
    encoded_target.scatter_(1,
                            target.view(target.size(0), 1,
                                        target.size(1), target.size(2)),
                            1)

    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    num = output * encoded_target  # b,c,h,w--p*g
    num = torch.sum(num, dim=3)  # b,c,h
    num = torch.sum(num, dim=2)

    den1 = output * output  # p^2
    den1 = torch.sum(den1, dim=3)  # b,c,h
    den1 = torch.sum(den1, dim=2)

    den2 = encoded_target * encoded_target  # g^2
    den2 = torch.sum(den2, dim=3)  # b,c,h
    den2 = torch.sum(den2, dim=2)  # b,c

    dice = (2 * num / (den1 + den2))

    dice_total = -1 * torch.sum(dice) / dice.size(0)
    return dice_total
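
For reference, the quantity computed by the function above can be written out explicitly. This is only a sketch: p stands for the tensor passed as output, g for the one-hot encoded target, and the index names b, c, h, w and the batch size B are notation introduced here, not taken from the code.

```latex
D_{b,c} = \frac{2 \sum_{h,w} p_{b,c,h,w}\, g_{b,c,h,w}}
               {\sum_{h,w} p_{b,c,h,w}^{2} + \sum_{h,w} g_{b,c,h,w}^{2}},
\qquad
\mathcal{L} = -\frac{1}{B} \sum_{b=1}^{B} \sum_{c=1}^{C} D_{b,c}
```

Since g only takes the values 0 and 1, the squared target term equals the plain sum of g.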

If you think of a solution that does not require one-hot encoding to evaluate the dice similarity of a multi-class problem, I am also interested!

Thanks

The in-place operation error comes from building encoded_target.

encoded_target is not differentiable anyway, so build it first and then wrap it in a Variable, like this:

encoded_target = output.data.clone()
encoded_target[...] = 0
encoded_target.scatter_(1, target.unsqueeze(1), 1)
encoded_target = Variable(encoded_target)
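
On recent PyTorch versions the same one-hot target can probably be built without any in-place write. The following is a minimal sketch, assuming torch.nn.functional.one_hot is available (PyTorch 1.1 or later) and that target holds class indices of shape N x H x W. The helper name one_hot_target is hypothetical.

```python
import torch
import torch.nn.functional as F


def one_hot_target(target, num_classes):
    """Hypothetical helper: N x H x W class indices -> N x C x H x W floats."""
    # F.one_hot appends the class dimension last: N x H x W x C
    encoded = F.one_hot(target, num_classes=num_classes)
    # move the class dimension next to the batch dimension
    return encoded.permute(0, 3, 1, 2).float()


target = torch.tensor([[[0, 2], [1, 1]]])  # N=1, H=2, W=2
encoded = one_hot_target(target, num_classes=3)
print(encoded.shape)  # torch.Size([1, 3, 2, 2])
```

The result does not require gradients, which is fine because the target is a constant with respect to the loss.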

You are right, I don’t even need it to be differentiable. Here is a new solution. However, I would like to extend the original problem with a new feature: ignore_index.

def dice_loss(output, target, weights=1):
    encoded_target = output.data.clone().zero_()
    encoded_target.scatter_(1, target.unsqueeze(1), 1)
    encoded_target = Variable(encoded_target)

    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    num = (output * encoded_target).sum(dim=3).sum(dim=2)
    den1 = output.pow(2).sum(dim=3).sum(dim=2)
    den2 = encoded_target.pow(2).sum(dim=3).sum(dim=2)

    dice = (2 * num / (den1 + den2)) * weights
    return dice.sum() / dice.size(0)

In semantic segmentation we generally have a label that we want to exclude from the loss; NLLLoss already supports this through its ignore_index parameter.
I would like to implement the same for this dice loss. I have already thought of two solutions, but I don’t like either of them:

  • the worst: re-encode all the labels so that ignore_index becomes a valid new label, which means modifying my classifier layer. This is really ugly for a lot of reasons.
  • inside the loss function, remap ignore_index to a new label, expand the output to match the new size, and finally drop this label at the end. I don’t really like this solution either: it involves copying and modifying the targets and expanding the channel dimension of the output tensor (I think).

If you have already faced this kind of problem, I would like to have your point of view on this.
Thanks !

Here is my solution to the ignore_index feature; I am not sure it is 100% correct. I added some comments so you can follow the logic. It is simple masking of the tensors.

def dice_loss(output, target, weights=1, ignore_index=None):
    encoded_target = output.data.clone().zero_()
    if ignore_index is not None:
        # mask of pixels carrying the label to ignore
        mask = target == ignore_index
        # clone target to not affect the original variable ?
        filtered_target = target.clone()
        # replace invalid label with whatever legal index value
        filtered_target[mask] = 0
        # one hot encoding
        encoded_target.scatter_(1, filtered_target.unsqueeze(1), 1)
        # expand the mask for the encoded target array
        mask = mask.unsqueeze(1).expand(output.data.size())
        # apply 0 to masked pixels
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
    encoded_target = Variable(encoded_target)

    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    num = (output * encoded_target).sum(dim=3).sum(dim=2)
    den1 = output.pow(2)
    den2 = encoded_target.pow(2)
    if ignore_index is not None:
        # exclude masked values from den1
        den1[mask] = 0

    dice = 2 * (num / (den1 + den2).sum(dim=3).sum(dim=2)) * weights
    return -dice.sum() / dice.size(0)

I finally got something to work:

def dice_loss(output, target, weights=None, ignore_index=None):
    """
    output : NxCxHxW Variable of log-probabilities (e.g. log_softmax output)
    target :  NxHxW LongTensor
    weights : C FloatTensor
    ignore_index : int index to ignore from loss
    """
    eps = 0.0001

    output = output.exp()
    encoded_target = output.detach() * 0
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = 1

    intersection = output * encoded_target
    numerator = 2 * intersection.sum(0).sum(1).sum(1)
    denominator = output + encoded_target

    if ignore_index is not None:
        denominator[mask] = 0
    denominator = denominator.sum(0).sum(1).sum(1) + eps
    loss_per_channel = weights * (1 - (numerator / denominator))

    return loss_per_channel.sum() / output.size(1)

It’s a combination of code seen on GitHub, tested on a 2D semantic segmentation problem.
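
For readers on a recent PyTorch, the same masking idea can be sketched without in-place writes by multiplying with a validity mask. This is an illustrative variant, not the exact function above: it assumes torch.nn.functional.one_hot is available, drops the weights argument, and the name dice_loss_ignore and the ignore value 255 are hypothetical.

```python
import torch
import torch.nn.functional as F


def dice_loss_ignore(log_probs, target, ignore_index, eps=1e-4):
    """Sketch: batch-pooled Dice loss that skips pixels labelled ignore_index."""
    probs = log_probs.exp()                      # N x C x H x W
    valid = target != ignore_index               # N x H x W, True where the label counts
    safe_target = target.masked_fill(~valid, 0)  # any legal class index will do
    encoded = F.one_hot(safe_target, probs.size(1)).permute(0, 3, 1, 2)
    keep = valid.unsqueeze(1).to(probs.dtype)    # N x 1 x H x W, broadcasts over classes
    encoded = encoded.to(probs.dtype) * keep     # ignored pixels become all-zero
    probs = probs * keep                         # out-of-place masking of the prediction
    numerator = 2 * (probs * encoded).sum(dim=(0, 2, 3))
    denominator = (probs + encoded).sum(dim=(0, 2, 3)) + eps
    return (1 - numerator / denominator).mean()  # average over the C channels


torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4, requires_grad=True)
target = torch.randint(0, 3, (2, 4, 4))
target[0, 0, 0] = 255  # hypothetical "ignore" label
loss = dice_loss_ignore(F.log_softmax(logits, dim=1), target, ignore_index=255)
loss.backward()
```

Because the ignored pixel is multiplied by zero in both the numerator and the denominator, it receives no gradient.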

Hi,
Why are you using loss_per_channel instead of computing the total loss over all channels? Are you getting multiple losses? For example, could you explain whether a prediction of size [10,3,5,5] with a ground truth of size [10,1,5,5] will work?
Best

Hi,
If your example means you have 11 labels, this loss will average 11 dice losses, one for each channel. I chose to compute the loss per channel in case I need to weight the loss of each channel. The function returns the global dice loss, not the per-channel losses.

Hi trypag,
Thanks a lot for your support. I have developed this code for the dice similarity measure, following your code.

def dice_loss(self,output, target, weights=None, ignore_index=None):
    # output : NxCxHxW Variable of float tensor
    # target :  NxHxW long tensor
    # weights : C float tensor
    # ignore_index : int value to ignore from loss
    smooth = 1.
    loss = 0.

    output = output.exp()   # computes the exponential of each element, i.e. for 0 it gives 1
    encoded_target = output.data.clone().zero_() # make output size array and initialize with zeros
    #ignore_index=1

    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        # scatter_ needs a 4D index, so target must already be N x 1 x H x W in this branch
        unseq = target.long()  # here
        unseq = unseq.data  # here
        encoded_target.scatter_(1, unseq, 1)

    encoded_target = Variable(encoded_target)

    if weights is None:
        weights = Variable(torch.ones(output.size(1)).type_as(output.data))

    intersection = output * encoded_target
    numerator = 2 * intersection.sum(3).sum(2).sum(0) + smooth
    denominator = (output + encoded_target).sum(3).sum(2).sum(0) + smooth
    loss_per_channel = weights * (1 - (numerator / denominator)) # weights may be directly multiplied

    return loss_per_channel.sum() / output.size(1)

The code seems to be working fine, but there are three things to consider. 1) What is the purpose of taking the exponential of the target variable? 2) I had to change a few lines, shown in bold, because without them I was getting errors. Can you have a look at this code? 3) As given in the paper “Generalized Dice overlap as a deep learning loss function for highly unbalanced segmentation”, the weights should be applied as (1-weights*(numerator/denominator)); the same is done in the generalized dice loss function here: https://cmiclab.cs.ucl.ac.uk/CMIC/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py
Best

Output is not the target variable, it’s the output of my model, the feature vectors.
I am using the exponential because the output of my model is log(softmax), so to obtain the softmax I take the exponential of log(softmax). The original formulation is written with the softmax; I just had to adapt it to my model.
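
To make the exp step concrete, here is a small dependency-free sketch in plain Python (not the PyTorch call itself) showing that exponentiating log-softmax values gives back probabilities that sum to 1. The helper log_softmax and the example scores are made up for illustration.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(scores)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_norm for s in scores]


scores = [2.0, 1.0, 0.1]            # raw class scores for one pixel
log_probs = log_softmax(scores)     # what a log_softmax layer would output
probs = [math.exp(lp) for lp in log_probs]  # exp recovers the softmax
print(sum(probs))
```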

I don’t see any code in bold. I noticed you changed encoded_target = output.data.clone().zero_(); the original was encoded_target = output.detach() * 0, which should have worked if output is a Variable.

Hi @trypag, I have a question about the line "numerator = 2 * intersection.sum(0).sum(1).sum(1)". Here you sum over the batch first, which doesn’t seem right to me. The dice loss should be calculated for every example and then summed together. Am I correct?

Summing over the batch first when computing the numerator and denominator makes the code compute an approximation of dice, not the exact dice loss.
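
A tiny numeric sketch of the difference, in plain Python with made-up values for a single class and two samples of four pixels each. It uses the 2 * sum(p * g) / (sum(p) + sum(g)) form from the later versions of the code, without the eps or smooth term.

```python
def dice(pred, truth):
    """Soft Dice score for flat lists of predictions and 0/1 labels."""
    inter = sum(p * g for p, g in zip(pred, truth))
    return 2 * inter / (sum(pred) + sum(truth))


preds = [[1.0, 1.0, 0.0, 0.0], [0.5, 0.0, 0.0, 0.0]]
truths = [[1, 1, 0, 0], [0, 1, 0, 0]]

# Dice per example, then averaged over the batch
per_sample = sum(dice(p, g) for p, g in zip(preds, truths)) / len(preds)

# Dice computed once on all pixels of the batch pooled together
pooled = dice([x for p in preds for x in p], [x for g in truths for x in g])

print(per_sample)  # 0.5
print(pooled)      # about 0.727
```

The first sample is predicted perfectly and the second is missed entirely, so the per-example average is 0.5, while pooling lets the larger, well-segmented sample dominate the score.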