Dice Loss Autograd

I came across this implementation of Dice:

import torch
from torch.autograd import Function


class DiceCoeff(Function):
    """Dice coeff for individual examples"""

    @staticmethod
    def forward(ctx, input, target):
        eps = 0.0001
        inter = torch.dot(input.view(-1), target.view(-1))
        union = torch.sum(input) + torch.sum(target) + eps
        # Save inputs and intermediates for the backward pass
        ctx.save_for_backward(input, target, inter, union)
        ctx.eps = eps

        t = (2 * inter + eps) / union
        return t

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, target, inter, union = ctx.saved_tensors
        grad_input = grad_target = None

        if ctx.needs_input_grad[0]:
            # Quotient rule applied to (2 * inter + eps) / union
            grad_input = grad_output * (2 * target * union - (2 * inter + ctx.eps)) \
                         / (union * union)
        if ctx.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


# Usage with static methods: dice = DiceCoeff.apply(input, target)
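For reference, the gradient used in backward follows from the quotient rule. Writing x for the flattened input, t for the flattened target, I for their dot product (inter), U for the smoothed sum (union) and ε for eps, the derivation is:

```latex
\begin{aligned}
I &= \sum_i x_i t_i, \qquad U = \sum_i x_i + \sum_i t_i + \varepsilon, \qquad D = \frac{2I + \varepsilon}{U} \\
\frac{\partial I}{\partial x_i} &= t_i, \qquad \frac{\partial U}{\partial x_i} = 1 \\
\frac{\partial D}{\partial x_i} &= \frac{2 t_i U - (2I + \varepsilon)}{U^2} \\
\text{grad\_input}_i &= \text{grad\_output} \cdot \frac{\partial D}{\partial x_i}
\end{aligned}
```

The expression in the version as originally posted, 2(t_i U − I)/U², differs from this exact gradient only by a −ε/U² term, so the two agree closely when ε is small.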

Until now, my Dice score had no backward method; it was just the regular view(-1) and intersection over union. Does that mean that nothing was being learned, since no backward method was implemented?

Can someone respond to this?

If you use PyTorch methods only, you won’t need to write the backward pass yourself in most cases.
To check whether the Dice loss was working before, you could call backward() on it and look at the gradients of your model's parameters.
You could, however, write your own backward method, e.g. if you have an idea to speed up the computation.
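A minimal sketch of the check suggested above, assuming a soft Dice loss built only from standard PyTorch ops. The function name `soft_dice_loss`, the toy model (a 1x1 `torch.nn.Conv2d` followed by a sigmoid) and the random data are hypothetical stand-ins for a real segmentation setup; the point is that `loss.backward()` fills in parameter gradients without any hand-written backward method.

```python
import torch


def soft_dice_loss(input, target, eps=1e-4):
    # Only differentiable torch ops, so autograd derives backward by itself
    inter = torch.dot(input.reshape(-1), target.reshape(-1))
    union = input.sum() + target.sum() + eps
    return 1 - (2 * inter + eps) / union  # minimize 1 - Dice


torch.manual_seed(0)
# Hypothetical stand-in for a segmentation network: 1x1 conv + sigmoid
model = torch.nn.Conv2d(1, 1, kernel_size=1)
x = torch.randn(2, 1, 8, 8)
target = (torch.rand(2, 1, 8, 8) > 0.5).float()

loss = soft_dice_loss(torch.sigmoid(model(x)), target)
loss.backward()

# A loss that is connected to the graph gives every parameter a gradient
for name, p in model.named_parameters():
    print(name, p.grad is not None, p.grad.abs().sum().item())
```

If a parameter's `.grad` stays `None` after `backward()`, the loss is not connected to that parameter (for example, because an intermediate result was detached or thresholded with a non-differentiable op).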

Hi, I have this problem too. My model can't learn during training. I changed the Dice loss function many times, but nothing changed.
But I didn't use a custom backward function; I just used the forward function. Have you used it? Was it useful?

I have done it before. Dice loss can be coded without defining any custom function. I did it both ways, defining a custom function and also using nn.Module, to check that both gave the same results. Please check your code and follow the guidelines on how to define a custom function; the syntax varies slightly depending on which version of PyTorch you are using.
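One way to run the comparison described above, as a sketch: the same Dice coefficient written once as a custom `Function` (static `forward`/`backward` taking `ctx`, the style used by current PyTorch releases) and once as an `nn.Module`. The class names `DiceCoeffFn` and `DiceCoeffModule` and the constant `EPS` are hypothetical. Double precision is used because `torch.autograd.gradcheck` needs it for a reliable finite-difference comparison.

```python
import torch
from torch.autograd import Function

EPS = 1e-4  # smoothing term, same value as eps in the snippet above


class DiceCoeffFn(Function):
    """Dice coefficient as a custom autograd Function."""

    @staticmethod
    def forward(ctx, input, target):
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        union = input.sum() + target.sum() + EPS
        ctx.save_for_backward(target, inter, union)
        return (2 * inter + EPS) / union

    @staticmethod
    def backward(ctx, grad_output):
        target, inter, union = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Quotient rule applied to (2 * inter + EPS) / union
            grad_input = grad_output * (2 * target * union - (2 * inter + EPS)) / union**2
        return grad_input, None  # no gradient for the target


class DiceCoeffModule(torch.nn.Module):
    """The same coefficient from plain tensor ops; autograd builds backward."""

    def forward(self, input, target):
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        union = input.sum() + target.sum() + EPS
        return (2 * inter + EPS) / union


torch.manual_seed(0)
pred = torch.rand(2, 1, 4, 4, dtype=torch.double)
target = (torch.rand(2, 1, 4, 4) > 0.5).double()

a = pred.clone().requires_grad_()
b = pred.clone().requires_grad_()
d_custom = DiceCoeffFn.apply(a, target)
d_module = DiceCoeffModule()(b, target)
d_custom.backward()
d_module.backward()

same_value = torch.allclose(d_custom, d_module)
same_grad = torch.allclose(a.grad, b.grad)
# gradcheck compares the hand-written backward against finite differences
grad_ok = torch.autograd.gradcheck(
    lambda p: DiceCoeffFn.apply(p, target), (pred.clone().requires_grad_(),)
)
print(same_value, same_grad, grad_ok)
```

If `gradcheck` fails for a custom `Function`, the hand-written backward is the first place to look; the `nn.Module` version never needs one, because autograd derives it from the forward ops.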
