Understanding the Backward Hook

I recently learned about register_backward_hook and register_forward_hook for nn.Module.

I have some questions about register_backward_hook. This is the sample code I am using:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4, 1)

    def forward(self, x):
        o = self.fc1(x)
        o = self.relu1(o)
        o = self.fc2(o)
        return o


forward_values = {}
backward_values = {}

# Define the forward hook function
def hook_fn_forward(module, inp, out):
    forward_values[module] = {}
    forward_values[module]["input"] = inp
    forward_values[module]["output"] = out
    
# Define the backward hook function
def hook_fn_backward(module, inp_grad, out_grad):
    backward_values[module] = {}
    backward_values[module]["input"] = inp_grad
    backward_values[module]["output"] = out_grad


model = Model()

modules = model.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)
    module.register_backward_hook(hook_fn_backward)

# batch size of 1 -> shape = (1,3)
x = torch.tensor([[1.0, 1.0, 1.0]])
o = model(x)
o.backward()

I was checking the gradients of the input and output for layer self.fc2 (the last layer).
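Roughly, I inspect the captured values like this (a minimal sketch using the dictionaries defined above):

# Look up what the hooks recorded for the last layer (fc2).
# For the backward hook, "input" holds grad_input and "output" holds grad_output.
print("grad_output:", backward_values[model.fc2]["output"])
print("grad_input:", backward_values[model.fc2]["input"])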

The output gradient of 1 makes sense (o.backward() on a single-element output seeds the gradient with 1). But why is the input gradient a tuple, and why does it have 3 values? Can anyone explain what these depict?

I am using PyTorch version 1.7.1.

Hi,

Unfortunately, the Module backward hooks have been broken forever for such “complex” models.
If you’re using a recent version of PyTorch, you can use the “full” versions (register_full_backward_hook, see the Module — PyTorch 1.9.0 documentation), which have the expected behavior.
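As I understand it, the old hook effectively fires on the last autograd op inside the module rather than on the module itself, which is why for a Linear layer grad_input shows up as a 3-tuple (roughly, the gradients for the bias, the input, and the transposed weight). Here is a minimal sketch of the registration loop from the question switched to the full hooks (register_full_backward_hook, available since around PyTorch 1.8):

# Same model and hook functions as above, but registered with the "full"
# backward hook, which reports gradients w.r.t. the module's actual
# inputs and outputs.
model = Model()
for name, module in model.named_children():
    module.register_forward_hook(hook_fn_forward)
    module.register_full_backward_hook(hook_fn_backward)

x = torch.tensor([[1.0, 1.0, 1.0]])
model(x).backward()

# grad_input for fc2 is now a 1-tuple: the gradient w.r.t. its 4-dim input.
print(backward_values[model.fc2]["input"])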

Thanks. I updated to 1.9.0 and it works as expected.
