Can anyone help me understand the following PyTorch code?

import torch
import torch.nn as nn
from torch.autograd.function import Function

class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CenterLoss, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # To check the dim of centers and features
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim, feat.size(1)))
        loss = self.centerlossfunc(feat, label, self.centers)
        loss /= (batch_size if self.size_average else 1)
        return loss


class CenterlossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centers):
        ctx.save_for_backward(feature, label, centers)
        centers_batch = centers.index_select(0, label.long())
        return (feature - centers_batch).pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature
        # init every iteration
        counts = centers.new(centers.size(0)).fill_(1)
        ones = centers.new(label.size(0)).fill_(1)
        grad_centers = centers.new(centers.size()).fill_(0)

        counts = counts.scatter_add_(0, label.long(), ones)
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers/counts.view(-1, 1)
        return - grad_output * diff, None, grad_centers

What exactly is the grad_output that gets passed to the backward method?

A Function is an elementary building block of the autograd engine. It should implement both the forward transformation, input -> output, and the backward transformation, grad_output -> grad_input. The backward operation corresponds to applying the chain rule of differentiation:
If you have a loss L and you are looking for dL/dinput, then by the chain rule dL/dinput = dL/doutput * doutput/dinput. The first term here is grad_output, and the second is what the backward function implements.
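For a concrete picture, here is a minimal toy sketch of my own (a Function that squares its input; it is not part of the code above). When the output is later reduced to a scalar loss and you call loss.backward(), autograd hands backward the quantity grad_output = dL/doutput, and backward multiplies it by doutput/dinput:

import torch
from torch.autograd.function import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x                       # output = x^2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # chain rule: dL/dx = dL/doutput * doutput/dx = grad_output * 2x
        return grad_output * 2 * x

x = torch.randn(3, requires_grad=True)
loss = Square.apply(x).sum()               # L = sum(x^2)
loss.backward()                            # here grad_output is a tensor of ones (dL/doutput)
print(torch.allclose(x.grad, 2 * x))       # True: dL/dx = 2x

In CenterlossFunc above, grad_output plays the same role: it is d(final loss)/d(output of CenterlossFunc), e.g. 1/batch_size when size_average is True and you call backward on the loss returned by the module.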

The only question I am left with now is why the return statement, return - grad_output * diff, None, grad_centers, has a negative sign in front of grad_output.

I guess it is because the forward pass computes (feature - centers_batch), while backward defines diff = centers_batch - feature, i.e. with the sign flipped. The extra minus sign undoes that flip so the gradient with respect to the feature comes out correct.
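As a quick numeric sanity check (my own snippet, not part of the original post; it assumes the CenterLoss and CenterlossFunc classes above are already defined in scope), the custom backward's gradient for feat matches what autograd computes for the same expression written with plain tensor ops, which confirms the sign:

import torch

torch.manual_seed(0)
center_loss = CenterLoss(num_classes=5, feat_dim=3, size_average=False)
label = torch.tensor([0, 2, 2, 4])
feat = torch.randn(4, 3, requires_grad=True)

loss = center_loss(label, feat)
loss.backward()

# the same loss written with ordinary ops; autograd should give the same d(loss)/d(feat)
feat2 = feat.detach().clone().requires_grad_(True)
ref = (feat2 - center_loss.centers.detach()[label]).pow(2).sum() / 2.0
ref.backward()
print(torch.allclose(feat.grad, feat2.grad))   # True: -grad_output * diff == feature - centers_batch

Note that the gradient returned for the centers is a different story: it is averaged per class (see the counts tensor) rather than being the exact gradient of the loss, so a gradcheck on the centers would not be expected to pass.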


Can I generalize from this that a custom Function like this is the best way to implement a custom loss?

If your loss function is differentiable, you don't actually need to implement a Function. Just write an nn.Module that contains the forward pass and the gradients will be computed for you 🙂 (if they exist).
If it's not differentiable, or the gradients that you want are not the ones for the loss you compute, then yes, creating a Function this way is the way to go.
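For the differentiable case, here is a sketch of the same center loss written as a plain nn.Module, so autograd derives the backward pass on its own (CenterLossSimple is just an illustrative name). This gives the exact gradient of the loss with respect to the centers, whereas the custom Function above divides the center gradient by a per-class count, which is precisely the "gradients you want are not the ones for the loss you compute" situation:

import torch
import torch.nn as nn

class CenterLossSimple(nn.Module):
    # Same forward computation as the snippet above, but with no custom
    # Function: autograd differentiates the ordinary tensor ops for us.
    def __init__(self, num_classes, feat_dim, size_average=True):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.size_average = size_average

    def forward(self, label, feat):
        feat = feat.view(feat.size(0), -1)
        centers_batch = self.centers.index_select(0, label.long())
        loss = (feat - centers_batch).pow(2).sum() / 2.0
        if self.size_average:
            loss = loss / feat.size(0)
        return loss

# usage: gradients for both feat and criterion.centers come from autograd
criterion = CenterLossSimple(num_classes=10, feat_dim=2)
feat = torch.randn(8, 2, requires_grad=True)
label = torch.randint(0, 10, (8,))
loss = criterion(label, feat)
loss.backward()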
