You could pass a function as the hook to register_hook, which will be called every time the gradient is calculated.
This might be useful for debugging purposes, e.g. just printing the gradient or its statistics, or you could of course manipulate the gradient in a custom way, e.g. normalizing it somehow etc.
x = torch.randn(1, 1)
w = torch.randn(1, 1, requires_grad=True)
w.register_hook(lambda x: print(x))
y = torch.randn(1, 1)
out = x * w
loss = (out - y)**2
loss.register_hook(lambda x: print(x))
loss.mean().backward(gradient=torch.tensor([0.1])) # prints the gradient in w and loss
@ptrblck I was wondering if it is possible to set requires_grad = True for the registered hooks.
More specifically, I am registering hooks for a recurrent network and want to know the gradients of the gradients (second derivative) i.e. Second derivate of the loss w.r.t. each hidden state.
I could not find a method to do this directly.