Manipulating gradients in backward

According to the docs, the hook …

can optionally return a new gradient with respect to input that will be used in place of grad_input in subsequent computations.

In other words, if your hook returns replacement gradients, those gradients are what gets propagated to every layer that comes before the hooked module.

Need convincing…

import torch
import torch.nn as nn
from torch.autograd import Variable

a = Variable(torch.randn(2, 2), requires_grad=True)
m = nn.Linear(2, 1)
m(a).mean().backward()
print(a.grad)
# shows a 2x2 tensor of non-zero values

def hook(module, grad_input, grad_output):
    # replace every gradient in grad_input with zeros of the same shape
    return (torch.zeros(grad_input[0].size()),
            torch.zeros(grad_input[1].size()),
            torch.zeros(grad_input[2].size()))

m.register_backward_hook(hook)

a.grad.zero_()
m(a).mean().backward()
print(a.grad)
# shows a 2x2 tensor of zeros
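The snippet above has only one layer, so the "layers that come before it" part isn't really visible. Below is a minimal sketch of the same idea with two stacked Linear layers, assuming a reasonably recent PyTorch where register_full_backward_hook is available as the newer replacement for register_backward_hook; the layer names and the zero_grad_input helper are just for illustration. The hook on the second layer returns zeros, so the input and the first layer's parameters end up with zero gradients, while the hooked layer's own weight gradient (which depends on grad_output, not grad_input) is unaffected.

import torch
import torch.nn as nn

first = nn.Linear(2, 2)
second = nn.Linear(2, 1)

def zero_grad_input(module, grad_input, grad_output):
    # grad_input here is a 1-tuple: the gradient w.r.t. `second`'s input
    return tuple(torch.zeros_like(g) for g in grad_input)

second.register_full_backward_hook(zero_grad_input)

x = torch.randn(2, 2, requires_grad=True)
second(first(x)).mean().backward()

print(x.grad)              # zeros: the replacement propagated past `first`
print(first.weight.grad)   # zeros for the same reason
print(second.weight.grad)  # non-zero: computed from grad_output, not grad_input

Note that in the original snippet grad_input has three elements because the old register_backward_hook attaches to the module's last autograd op rather than to the module boundary, which is part of why it was deprecated in favor of the full backward hook.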