import torch
import torch.nn as nn
from torch.autograd.function import Function


class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CenterLoss, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # Check that the feature dimension matches the centers' dimension.
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(
                self.feat_dim, feat.size(1)))
        loss = self.centerlossfunc(feat, label, self.centers)
        loss /= (batch_size if self.size_average else 1)
        return loss


class CenterlossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centers):
        ctx.save_for_backward(feature, label, centers)
        centers_batch = centers.index_select(0, label.long())
        # Center loss: 0.5 * sum over the batch of ||feature_i - center_{label_i}||^2
        return (feature - centers_batch).pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature
        # Buffers are reinitialized on every iteration. counts starts at 1 so
        # that classes absent from the batch never divide by zero below.
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = centers.new_zeros(centers.size())
        counts = counts.scatter_add_(0, label.long(), ones)
        # Accumulate (center - feature) per class, then average by class count.
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers / counts.view(-1, 1)
        return - grad_output * diff, None, grad_centers
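For context, a minimal usage sketch (my own addition; the shapes and sizes are hypothetical, not from the original post):

criterion = CenterLoss(num_classes=10, feat_dim=2)
feat = torch.randn(16, 2, requires_grad=True)  # e.g. penultimate-layer features
label = torch.randint(0, 10, (16,))
loss = criterion(label, feat)
loss.backward()  # this call is what hands grad_output to CenterlossFunc.backward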
What exactly is the grad_output being passed to the backward method?
A Function is an elementary building block of the autograd engine. It should implement both the forward transformation, input -> output, and the backward transformation, grad_output -> grad_input. This backward operation corresponds to applying the chain rule of differentiation: if you have a loss L and you are looking for dL/dinput, then by the chain rule dL/dinput = dL/doutput * doutput/dinput. The first term here is grad_output, and the second is what is implemented in the backward function.
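As a concrete illustration (a minimal sketch of my own, not from the thread), consider a toy Function that squares its input. Autograd supplies grad_output = dL/doutput to backward, and backward multiplies it by the local derivative doutput/dinput:

import torch
from torch.autograd.function import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x  # output = x^2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # dL/dx = dL/doutput * doutput/dx = grad_output * 2x
        return grad_output * 2 * x

x = torch.randn(4, requires_grad=True)
loss = Square.apply(x).sum()  # L = sum(x^2): dL/doutput is a tensor of ones
loss.backward()
print(torch.allclose(x.grad, 2 * x.detach()))  # True: matches the analytic gradient

In the CenterLoss code above, note that the division by batch_size happens after the Function, so the grad_output arriving in CenterlossFunc.backward is 1/batch_size rather than 1.0 when size_average is set.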
The only question I am left with now concerns the return statement, return - grad_output * diff, None, grad_centers: why is there a negative sign on grad_output?
I guess it is because even though the forward computes (feature - centers_batch), the backward defines diff = centers_batch - feature. The gradient of the loss with respect to feature is (feature - centers_batch) = -diff, so the extra minus sign is needed to get the proper gradients.
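A quick way to convince yourself of the sign (my own check, not part of the thread) is to compare the hand-written gradient against what autograd computes for the same expression:

import torch

feature = torch.randn(8, 2, requires_grad=True)
centers_batch = torch.randn(8, 2)

loss = (feature - centers_batch).pow(2).sum() / 2.0
loss.backward()

diff = centers_batch - feature.detach()
# loss is the final scalar here, so grad_output is 1.0 and the custom backward
# would return -1.0 * diff as the feature gradient.
print(torch.allclose(feature.grad, -diff))  # True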
Can I generalize from this: is this the best way to implement a custom loss?
If your loss function is differentiable, you don't actually need to implement a Function. Just write an nn.Module that contains the forward pass, and the gradients will be computed for you (if they exist). If it's not differentiable, or the gradients that you want are not the ones for the loss you compute, then yes, creating a Function this way is the way to go.
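For the differentiable case, here is a minimal sketch (my own, assuming you want exactly the gradients of the loss you compute) of the same center loss written as a plain nn.Module:

import torch
import torch.nn as nn

class CenterLossAutograd(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CenterLossAutograd, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.size_average = size_average

    def forward(self, label, feat):
        centers_batch = self.centers.index_select(0, label.long())
        # Same forward expression as the Function; gradients come from autograd.
        loss = (feat - centers_batch).pow(2).sum() / 2.0
        if self.size_average:
            loss = loss / feat.size(0)
        return loss

This also shows the distinction made above: with this module, autograd pushes the raw summed gradient into the centers, whereas the custom Function deliberately returns a per-class averaged grad_centers that differs from the gradient of the written loss.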