Custom loss (autograd vs. nn.Module): what is the difference?

I found two implementations of Dice loss for binary segmentation.
Here are the two versions:

1. Using autograd.

import torch

class Diceloss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target):
        ctx.save_for_backward(input, target)
        num = input.size()[0]
        eps = 0.0001
        ctx.inter = torch.dot(input.view(-1), target.view(-1))
        ctx.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * ctx.inter.float() + eps) / ctx.union.float()
        print(1-t)
        return (1-t)/num

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        grad_input = grad_target = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * ctx.union - ctx.inter) \
                         / (ctx.union * ctx.union)
        if ctx.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target

I used it like this:

criterion = Diceloss.apply
loss = criterion(predicted_mask, true_mask)
loss.backward()

2. Using nn.Module

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Module):
    """Dice loss of binary class
    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict
        reduction: Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    Raise:
        Exception if unexpected reduction
    """
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))

I used it like this:

criterion = BinaryDiceLoss()
loss = criterion(predicted_mask, true_mask)
loss.backward()

Both appeared to run, but with the first version the loss did not decrease.
The second version, on the other hand, trained fine.

I also referenced this tutorial: https://9bow.github.io/PyTorch-tutorials-kr-0.3.1/beginner/examples_autograd/two_layer_net_custom_function.html

but the version using autograd still did not work well.

What is the difference between a custom torch.autograd.Function (with both forward and backward) and an nn.Module (forward only)?
In which cases should I use an autograd Function (implementing both forward and backward) rather than an nn.Module (forward only, no backward)?

Thanks!

Hi,

From a quick look, it seems like your Module version handles the batch dimension differently from the autograd version, no?
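The autograd version computes a single Dice score over the whole flattened batch and divides by the batch size, while the Module version computes a per-sample Dice score and then reduces it. A rough, runnable sketch of that difference (the shapes and variable names here are just an example):

import torch

torch.manual_seed(0)
N = 4
predict = torch.rand(N, 1, 8, 8)
target = (torch.rand(N, 1, 8, 8) > 0.5).float()
eps, smooth = 0.0001, 1.0

# Autograd version: one Dice score over the whole flattened batch, divided by N.
inter = torch.dot(predict.view(-1), target.view(-1))
union = predict.sum() + target.sum() + eps
loss_autograd = (1 - (2 * inter + eps) / union) / N

# Module version: a per-sample Dice score (p=2, smoothed), reduced with 'mean'.
p_flat = predict.reshape(N, -1)
t_flat = target.reshape(N, -1)
num = (p_flat * t_flat).sum(dim=1) + smooth
den = (p_flat.pow(2) + t_flat.pow(2)).sum(dim=1) + smooth
loss_module = (1 - num / den).mean()

print(loss_autograd.item(), loss_module.item())  # the two values generally differ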

Also, once you are sure that the two forwards give the same result, you can check the backward implementation of the autograd Function with torch.autograd.gradcheck(Diceloss.apply, (sample_input, sample_target)), where the inputs are double precision and the input requires gradients.
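A minimal sketch of such a check, assuming the Diceloss Function defined above is in scope (the shapes are just an example):

import torch

# gradcheck compares the hand-written backward against numerically estimated
# gradients and raises an error if they do not match.
sample_input = torch.rand(4, 1, 8, 8, dtype=torch.double, requires_grad=True)
sample_target = (torch.rand(4, 1, 8, 8, dtype=torch.double) > 0.5).double()

torch.autograd.gradcheck(Diceloss.apply, (sample_input, sample_target))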

Thanks, I will check this.
May I ask another question?
Assuming both the forward and the backward of the autograd Function are implemented correctly, what are the differences between the two custom loss functions, one based on autograd and the other on nn.Module?

They will give you exactly the same result (up to floating-point precision).
The one based on autograd might use less memory and be slightly faster if you have a special backward that is more efficient than the automatically generated one. But it is usually not recommended, as it is more complex and more error prone.
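If you do want to keep a hand-written backward, a common pattern is to wrap the autograd Function in a small nn.Module so it is called the same way as any other loss. A minimal sketch, assuming the Diceloss Function from the first snippet (the wrapper name is just illustrative):

import torch.nn as nn

class DiceLossWrapper(nn.Module):
    """Exposes the custom Diceloss Function through the usual Module interface."""
    def forward(self, predict, target):
        # Diceloss.apply invokes the custom forward/backward defined earlier.
        return Diceloss.apply(predict, target)

criterion = DiceLossWrapper()
loss = criterion(predicted_mask, true_mask)
loss.backward()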


Nice, thanks! That's really helpful.