I wrote this snippet below to try and understand what’s going on with these hooks.
```python
import numpy as np
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
        self.fc1.register_forward_hook(self._forward_hook)
        self.fc1.register_backward_hook(self._backward_hook)

    def forward(self, inp):
        return self.fc2(self.fc1(inp))

    def _forward_hook(self, module, input, output):
        print(type(input))
        print(len(input))
        print(type(output))
        print(input[0].shape)  # input is a tuple, so index into it
        print(output.shape)
        print()

    def _backward_hook(self, module, grad_input, grad_output):
        print(type(grad_input))
        print(len(grad_input))
        print(type(grad_output))
        print(len(grad_output))
        print(grad_input[0].shape)
        print(grad_input[1].shape)
        print(grad_output[0].shape)
        print()


model = Model()
out = model(torch.tensor(np.arange(10).reshape(1, 1, 10), dtype=torch.float32))
out.backward()
```
This prints:

```
<class 'tuple'>
1
<class 'torch.Tensor'>
torch.Size([1, 1, 10])
torch.Size([1, 1, 5])

<class 'tuple'>
2
<class 'tuple'>
1
torch.Size([1, 1, 5])
torch.Size([5])
torch.Size([1, 1, 5])
```
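In case the hook variant matters: below is a minimal sketch of the same probe using the newer `register_full_backward_hook` (assuming PyTorch 1.8+, where the plain `register_backward_hook` is deprecated). As I read the docs, it should report only the gradients with respect to the module's inputs and outputs, never its parameters:

```python
import numpy as np
import torch
import torch.nn as nn


class FullHookModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
        # Full backward hooks report grads w.r.t. the module's inputs/outputs only
        self.fc1.register_full_backward_hook(self._backward_hook)

    def forward(self, inp):
        return self.fc2(self.fc1(inp))

    def _backward_hook(self, module, grad_input, grad_output):
        # Entries are None for inputs that don't require grad, hence the guard
        print([None if g is None else g.shape for g in grad_input])
        print([None if g is None else g.shape for g in grad_output])


model = FullHookModel()
# requires_grad=True so the gradient w.r.t. fc1's input is actually computed
x = torch.tensor(np.arange(10).reshape(1, 1, 10), dtype=torch.float32,
                 requires_grad=True)
model(x).backward()
```

If I understand the documented semantics correctly, this should print `[torch.Size([1, 1, 10])]` for `grad_input` and `[torch.Size([1, 1, 5])]` for `grad_output`, i.e. matching the forward-pass input and output shapes.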
You can also follow the CNN example here. In fact, it’s needed to understand the rest of my question.
I have a few questions:
1. I would normally think that `grad_input` (backward hook) should be the same shape as `output` (forward hook), because when we go backwards the direction is reversed. But the CNN example seems to indicate otherwise. I'm still a bit confused. Which way around is it?
2. Why are neither `grad_input` nor `grad_output` the same shape as the input on my `Linear` layer here? Regardless of the answer to my question 1, at least one of them should be `torch.Size([1, 1, 10])`, right?
3. What's with the second element of the tuple `grad_input`? In the CNN case I copy-pasted the example and printed `grad_input[1].shape`, which gave `torch.Size([20, 10, 5, 5])`, so I presume it's the gradients of the weights. I also ran `grad_input[2].shape` and got `torch.Size([20])`, so it seemed clear I was looking at the gradients of the biases. But then in my `Linear` case, `grad_input` is length 2, so I can only access up to `grad_input[1]`, which seems to be giving me the gradients of the biases. So then where are the gradients of the weights?
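For what it's worth, the parameter gradients can be read directly off the layer after `backward()`, which is what I'm trying to reconcile the hook tuples against; with the `model` from my snippet above:

```python
# After out.backward(), autograd has accumulated the parameter gradients,
# independently of whatever the backward hook reports:
print(model.fc1.weight.grad.shape)  # torch.Size([5, 10]) -- gradients of the weights
print(model.fc1.bias.grad.shape)    # torch.Size([5])     -- gradients of the biases
```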
In summary, there are two apparent contradictions between the behaviour of the backward hook in the cases of the `Conv2d` and `Linear` modules. This has left me totally confused about what to expect from this hook.
Thanks for your help!