Hello,
Let’s say I have a model with the following layers:
Linear Layer → Softmax.
How can I write a backward hook so that the gradient computation of the softmax layer is ignored?
You could use register_full_backward_hook and return the grad_output:
import torch
import torch.nn as nn

def hook(module, grad_input, grad_output):
    # Returning grad_output replaces grad_input, so the module's own
    # backward computation is effectively bypassed
    return grad_output

x = torch.randn(1, 3, requires_grad=True)
lin = nn.Linear(3, 3, bias=False)
lin.register_full_backward_hook(hook)

out = lin(x)
out.mean().backward()
print(x.grad)
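To match the Linear → Softmax setup from your question, a minimal sketch could look like the following (the Sequential model and the choice to register the hook on the softmax module are my assumptions about what you want): the hook passes the incoming gradient straight through the softmax, so only the linear layer's backward contributes to x.grad.

import torch
import torch.nn as nn

def skip_softmax_grad(module, grad_input, grad_output):
    # Replace the softmax's grad_input with its grad_output,
    # i.e. pass the gradient through without applying the softmax Jacobian
    return grad_output

# Assumed model: Linear -> Softmax, as described in the question
model = nn.Sequential(
    nn.Linear(3, 3, bias=False),
    nn.Softmax(dim=1),
)
model[1].register_full_backward_hook(skip_softmax_grad)

x = torch.randn(1, 3, requires_grad=True)
out = model(x)
out.mean().backward()

# x.grad now reflects only the linear layer's backward,
# since the softmax gradient was bypassed by the hook
print(x.grad)

Note that returning grad_output only works shape-wise because the softmax's input and output have the same shape; for modules where they differ you would need to return a tensor matching the input shape.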