```
import torch
import torch.nn as nn
from torch.autograd.function import Function


class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CenterLoss, self).__init__()
        # One learnable center per class.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # Check that the feature dim matches the dim of the centers.
        if feat.size(1) != self.feat_dim:
            raise ValueError(
                "Center's dim: {0} should be equal to input feature's dim: {1}".format(
                    self.feat_dim, feat.size(1)))
        loss = self.centerlossfunc(feat, label, self.centers)
        loss /= (batch_size if self.size_average else 1)
        return loss


class CenterlossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centers):
        ctx.save_for_backward(feature, label, centers)
        # Pick the center that belongs to each sample's label.
        centers_batch = centers.index_select(0, label.long())
        return (feature - centers_batch).pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        # Note the order: this is the negative of (feature - centers_batch).
        diff = centers_batch - feature
        # Re-initialized on every iteration.
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = torch.zeros_like(centers)
        # counts[j] = 1 + number of samples in the batch with label j.
        counts = counts.scatter_add_(0, label.long(), ones)
        # Sum diff per class, then normalize by the per-class counts.
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers / counts.view(-1, 1)
        # One gradient per forward input: feature, label (none), centers.
        return - grad_output * diff, None, grad_centers
```

What exactly is the `grad_output` that is passed to the `backward` method?

A `Function` is an elementary building block of the autograd engine. It should implement both the forward transformation `input -> output` and the backward transformation `grad_output -> grad_input`. The backward operation corresponds to applying the chain rule of differentiation:

If you have a loss `L` and you are looking for `dL/dinput`, the chain rule gives `dL/dinput = dL/doutput * doutput/dinput`. The first term is `grad_output`, and the second is what is implemented in the backward function.
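To make this concrete, here is a minimal sketch using a toy `Function` that is not part of the code above. The names `Square` and `seen` are hypothetical and only serve the illustration. It records the `grad_output` that autograd hands to `backward`, so you can see that it is `dL/doutput` for whatever comes after the function.

```python
import torch
from torch.autograd import Function

seen = []  # hypothetical helper: stores the grad_output values autograd passes in


class Square(Function):
    """Toy Function for illustration only: y = x ** 2."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        seen.append(grad_output.detach().clone())
        # grad_output is dL/dy; 2 * x is dy/dx; their product is dL/dx.
        return grad_output * 2 * x


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = Square.apply(x)
loss = (3.0 * y).sum()  # so dL/dy = 3 for every element of y
loss.backward()

print(seen[0])  # the grad_output received by backward: 3 everywhere
print(x.grad)   # dL/dx = 3 * 2 * x
```

In the `CenterLoss` module above, the output of the function is divided by `batch_size` before being returned, so with `size_average=True` the `grad_output` reaching `CenterlossFunc.backward` would be `1 / batch_size` when `backward()` is called directly on the returned loss.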

The only question I am left with now is why there is a negative sign (**"-"**) in front of `grad_output` in **return - grad_output * diff, None, grad_centers**.

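I guess it is because, even though the forward computes `(feature - centers_batch)`, the backward defines `diff = centers_batch - feature`. The derivative of `(feature - centers_batch).pow(2).sum() / 2.0` with respect to `feature` is `feature - centers_batch`, which is `-diff`, so there is an extra `-` to get the proper gradients.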
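One way to check the sign is to let autograd differentiate the forward expression and compare the result with `-diff`. This is a small sketch with made-up shapes and labels, not part of the original code:

```python
import torch

torch.manual_seed(0)
feature = torch.randn(4, 2, requires_grad=True)  # 4 samples, 2-dim features (arbitrary)
centers = torch.randn(3, 2)                      # 3 classes (arbitrary)
label = torch.tensor([0, 2, 2, 1])

# Same expression as in CenterlossFunc.forward, differentiated by autograd.
centers_batch = centers.index_select(0, label)
loss = (feature - centers_batch).pow(2).sum() / 2.0
loss.backward()

# Same definition as in CenterlossFunc.backward.
diff = centers_batch - feature.detach()

# Here grad_output is 1, so the gradient w.r.t. feature should equal -diff.
matches = torch.allclose(feature.grad, -diff)
print(matches)
```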

Can I generalize and say that this is the best way to implement a custom loss?

If your loss function is differentiable, you don’t actually need to implement a `Function`. Just write an `nn.Module` that contains the forward pass, and the gradients will be computed for you (if they exist).
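As a sketch of that approach, the same forward computation can be written as a plain `nn.Module`. The class name `AutogradCenterLoss` is hypothetical. Note one difference: autograd returns the true gradient for the centers, not the count-normalized `grad_centers` that the custom `backward` above returns, so the center updates would not be identical.

```python
import torch
import torch.nn as nn


class AutogradCenterLoss(nn.Module):
    """Hypothetical variant: same forward pass, gradients left to autograd."""

    def __init__(self, num_classes, feat_dim, size_average=True):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        centers_batch = self.centers.index_select(0, label.long())
        loss = (feat - centers_batch).pow(2).sum() / 2.0
        return loss / batch_size if self.size_average else loss


criterion = AutogradCenterLoss(num_classes=3, feat_dim=2)
feat = torch.randn(4, 2, requires_grad=True)
label = torch.tensor([0, 2, 2, 1])

loss = criterion(label, feat)
loss.backward()  # no custom backward needed

print(feat.grad.shape)               # gradient w.r.t. the features
print(criterion.centers.grad.shape)  # gradient w.r.t. the centers
```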

If it’s not differentiable, or the gradients that you want are not the ones for the loss you compute, then yes, creating a `Function` like this is the way to go.
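For illustration, here is a minimal sketch of a case where a custom `Function` is needed. It uses a straight-through estimator for rounding, which is a different example from the center loss in this thread and was chosen only because it is short. The name `RoundSTE` is hypothetical. Rounding has a zero gradient almost everywhere, so the backward deliberately returns a gradient that is not the true derivative of the forward.

```python
import torch
from torch.autograd import Function


class RoundSTE(Function):
    """Hypothetical example: round in forward, pass the gradient through in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Pretend the forward was the identity: dL/dx = dL/dy.
        return grad_output


x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = RoundSTE.apply(x)
y.sum().backward()

print(y)       # rounded values
print(x.grad)  # ones, instead of the zeros that true differentiation would give
```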