I am trying to record the values of the gradients as they propagate back in time through an RNN. Initially I thought this would be easy to accomplish with the register_hook function, calling it on each parameter, but I now realize that for some reason the hook is not called at each step in time, contrary to what I had understood. Take for example the following code:
```python
import torch

x = torch.randn((1, 1))
w = torch.ones((2, 1), requires_grad=True)

z = w * x
z *= w
z *= w
z *= w

def print_grad(grad):
    print(grad)

h = z.register_hook(print_grad)
z.sum().backward()
```
This prints the gradient only once, instead of four times (once per multiplication step), which is what I expected.
What am I missing here? If nothing, is there still a way to record the gradient at each step in time?
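For reference, here is a sketch of the behavior I was hoping for: keeping each intermediate result as its own tensor (instead of multiplying in place) and registering a hook on each one does fire once per step. The variable names and the `grads` list are just my own illustration, not anything from the docs:

```python
import torch

x = torch.randn((1, 1))
w = torch.ones((2, 1), requires_grad=True)

grads = []          # collected gradients, one per step
z = w * x
steps = [z]
for _ in range(3):
    z = z * w       # new tensor each step, not in-place
    steps.append(z)

# hook every intermediate tensor so we see the gradient at each step
for t in steps:
    t.register_hook(lambda g: grads.append(g.clone()))

steps[-1].sum().backward()
print(len(grads))   # one recorded gradient per multiplication step
```

But doing this for the hidden state inside an nn.RNN, where I don't own the intermediate tensors, is exactly what I can't figure out.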
Thanks in advance!