Recording gradients in RNNs at each point in time

Hi,

I am trying to record the values of the gradients as they propagate back in time through an RNN. Initially I thought this would be easily accomplished with the register_hook function by calling it on each parameter, but I now realize that the hook is not called at each step in time, contrary to what I had understood. Take for example the following code:

import torch

x = torch.randn((1, 1))
w = torch.ones((2, 1), requires_grad=True)

z = w * x
z *= w
z *= w
z *= w

def print_grad(grad):
    print(grad)

h = z.register_hook(print_grad)
z.sum().backward()

This produces as output:

tensor([[1.],
        [1.]])

Instead of the same output repeated 4 times, once for each multiplication.

What am I missing here? If nothing, is there a way to record the gradients at each point in time?

Thanks in advance!

The hook will give you the gradient at the point where the hook is registered.
If you want to get it after every computation, you need to register one after every computation.
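
For example, sticking with your toy snippet (just a sketch of the idea; I replaced the in-place *= with out-of-place multiplications so that every intermediate result is a distinct tensor you can put a hook on):

import torch

x = torch.randn((1, 1))
w = torch.ones((2, 1), requires_grad=True)

def print_grad(grad):
    print(grad)

# Register a hook on every intermediate result, not only on the final one.
z1 = w * x
z1.register_hook(print_grad)
z2 = z1 * w
z2.register_hook(print_grad)
z3 = z2 * w
z3.register_hook(print_grad)
z4 = z3 * w
z4.register_hook(print_grad)

z4.sum().backward()  # print_grad is now called four times, last step first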


Oh, I see. I guess for feed-forward models that’s good enough. But for recurrent networks it makes tracing gradients backward in time a bit hard.

Thanks for the answer!

Well, if each time step is a single call inside a for-loop, you can register the hook inside that loop and you will get one hook call per time step. So that works for RNNs as well :slight_smile:

def print_grad(grad):
    print(grad)

z = w * x  # x and w as in your snippet above
for i in range(4):
    z = z * w                    # out-of-place, so each step produces a new tensor
    z.register_hook(print_grad)  # one hook on this step's intermediate result

z.sum().backward()  # print_grad fires once per time step

That will register one hook per iteration of the loop.
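
If you want to actually record the gradients rather than just print them, a minimal sketch along the same lines is to let the hook append to a list:

import torch

x = torch.randn((1, 1))
w = torch.ones((2, 1), requires_grad=True)

grads = []  # filled in during backward, last time step first

z = w * x
z.register_hook(grads.append)
for i in range(3):
    z = z * w
    z.register_hook(grads.append)

z.sum().backward()
print(len(grads))  # 4 gradients, one per time step

Since backward visits the steps in reverse order, grads[0] is the gradient at the last time step and grads[-1] the one at the first. The same idea applies to an actual RNN loop: register the hook on the hidden state tensor produced at each time step.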