From the docs, we know that we only need to write `__init__` and `forward` when extending `torch.nn`.
However, the tutorial adds a `backward` method as well:
```python
class ContentLoss(nn.Module):
    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
        # will throw an error.
        self.target = target.detach() * weight
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_variables=True):
        self.loss.backward(retain_variables=retain_variables)
        return self.loss
```
The reason for this is that we want to compute the gradient with respect to the computed loss `self.loss`, which is not the output of `forward` (the module passes its input through unchanged). So I have overridden the `backward` function to say that I want to back-propagate from this value instead.
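To make this concrete, here is a minimal, self-contained sketch of how such a module can be driven. It restates the `ContentLoss` from above, but note one assumption: it uses `retain_graph`, the current name of the argument that older PyTorch versions called `retain_variables`, and the tiny `nn.Linear` network and tensor shapes are purely illustrative.

```python
import torch
import torch.nn as nn

class ContentLoss(nn.Module):
    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        # detach the target so it is treated as a fixed value
        self.target = target.detach() * weight
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        # store the loss, but pass the input through unchanged
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        # back-propagate from self.loss, not from the module's output
        self.loss.backward(retain_graph=retain_graph)
        return self.loss

# hypothetical minimal usage: insert the module mid-network as a
# transparent layer, run a forward pass, then call its backward
torch.manual_seed(0)
target = torch.randn(1, 4)
content_loss = ContentLoss(target, weight=1.0)
model = nn.Sequential(nn.Linear(4, 4), content_loss)

x = torch.randn(1, 4, requires_grad=True)
model(x)                   # forward pass populates content_loss.loss
content_loss.backward()    # back-propagates from the hidden loss
print(x.grad is not None)  # True: the gradient reached the input
```

The key point the sketch illustrates is that `forward` returns `input` untouched, so the module is invisible to the layers after it, while the gradient of the stored loss still flows back to the input image.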
Hi alexis,
Thank you for your reply. Here is what I think about it.
Since neural style computes its losses at hidden layers, consider the original Torch implementation; take `ContentLoss` for example:
```lua
function ContentLoss:updateGradInput(input, gradOutput)
  if self.mode == 'loss' then
    if input:nElement() == self.target:nElement() then
      self.gradInput = self.crit:backward(input, self.target)
    end
    if self.normalize then
      self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
    end
    self.gradInput:mul(self.strength)
    -- This is where the layer loss has taken its effects.
    self.gradInput:add(gradOutput)
  else
    self.gradInput:resizeAs(gradOutput):copy(gradOutput)
  end
  return self.gradInput
end
```
We can see that `self.gradInput:add(gradOutput)` is where the "hidden layer gradient" is added into the main gradient flow.
In PyTorch, however, we don't need to add it ourselves: under PyTorch's autograd mechanics, the gradients arriving at a node are accumulated automatically. So to back-propagate the hidden-layer gradient, we only need to override the `backward` function.
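A minimal sketch of this accumulation behaviour (the toy "hidden" and "main" losses here are my own illustration, not from the tutorial): calling `backward` on two separate losses that share an upstream node adds their gradients into the same `.grad`, with no manual `add` needed.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 2                      # a "hidden layer" activation
hidden_loss = (y ** 2).sum()   # loss attached to the hidden layer
main_loss = y.sum()            # loss at the end of the network

# Back-propagate the hidden loss first; keep the graph alive so the
# main loss can also be back-propagated through y afterwards.
hidden_loss.backward(retain_graph=True)
print(x.grad.tolist())         # [8.0, 8.0, 8.0]  (d/dx of sum((2x)^2))

main_loss.backward()
print(x.grad.tolist())         # [10.0, 10.0, 10.0]: the 2.0 per element
                               # from the main loss was added automatically
```

This is exactly the role `self.gradInput:add(gradOutput)` plays in the Lua code: in PyTorch, autograd performs that addition for us.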
In fact, `loss.backward(...)` calls the backward function of `MSELoss`, which already implements that line of the back-propagation. If I really wanted to create my own loss function, I would have had to implement the backward pass with such a line myself.
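For comparison, here is a sketch of what writing that backward line yourself could look like, using a custom `torch.autograd.Function`. The class name `MyMSELoss` is purely illustrative; the backward formula is the standard gradient of the mean squared error, which the built-in `MSELoss` already implements for you.

```python
import torch

class MyMSELoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target):
        diff = input - target
        ctx.save_for_backward(diff)
        return (diff ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        (diff,) = ctx.saved_tensors
        # the line the built-in MSELoss backward already implements:
        # d/dinput of mean((input - target)^2) = 2 * (input - target) / N
        grad_input = grad_output * 2.0 * diff / diff.numel()
        # no gradient is needed with respect to the target
        return grad_input, None

# sanity check against the built-in criterion
x = torch.randn(5, requires_grad=True)
t = torch.randn(5)
MyMSELoss.apply(x, t).backward()
manual = x.grad.clone()

x.grad = None
torch.nn.functional.mse_loss(x, t).backward()
print(torch.allclose(manual, x.grad))  # True
```

So when the tutorial's `ContentLoss.backward` calls `self.loss.backward(...)`, a backward pass like the one above runs under the hood.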