I don’t think you should return the output_grads, as they were already passed to the module and won’t be used anymore. Instead, return all input gradients, which will be passed to the “previous” layer (previous in the forward-execution order). Also, use register_full_backward_hook, since register_backward_hook is deprecated and might not work correctly.
From the docs:
The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.
Here is a minimal example:
import torch
import torch.nn as nn

def backward_hook(m, input_gradients, output_gradients):
    print('input_gradients {}'.format(input_gradients))
    print('output_gradients {}'.format(output_gradients))
    # Return a new tuple of input gradients; it replaces grad_input
    input_gradients = (torch.ones_like(input_gradients[0]),)
    return input_gradients

conv = nn.Conv2d(1, 1, 3)
conv.register_full_backward_hook(backward_hook)

x = torch.randn(1, 1, 3, 3).requires_grad_()
out = conv(x)
out.mean().backward()
print(x.grad)  # ones
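To see that the returned input gradients really are what the “previous” layer receives, here is a sketch with two hypothetical conv layers, where the hook on the second layer zeroes out its input gradients and the first layer (and the input) consequently only see zeros:

```python
import torch
import torch.nn as nn

# Hypothetical two-layer setup (names and shapes chosen for illustration)
conv1 = nn.Conv2d(1, 1, 3, padding=1)
conv2 = nn.Conv2d(1, 1, 3, padding=1)

def zero_input_grads(m, grad_input, grad_output):
    # The returned tuple replaces grad_input and is what autograd
    # propagates backwards to conv1 instead of the original gradients.
    return (torch.zeros_like(grad_input[0]),)

conv2.register_full_backward_hook(zero_input_grads)

x = torch.randn(1, 1, 5, 5, requires_grad=True)
out = conv2(conv1(x))
out.mean().backward()

# conv1 and x only receive the zeroed gradients returned by the hook
print(x.grad.abs().sum())            # tensor(0.)
print(conv1.weight.grad.abs().sum())  # tensor(0.)
# conv2's own parameter gradients are computed before the replacement
# and are therefore unaffected
print(conv2.weight.grad.abs().sum() > 0)
```

Note that replacing grad_input only changes what flows to earlier layers; the hooked module’s own parameter gradients are left untouched.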