Backward hooks Explanation!

import torch
import torch.nn as nn
import torch.optim as optim

class ThreeLayerNet(nn.Module):
    def __init__(self):
        super(ThreeLayerNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 50)
        self.fc3 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def backward_hook(module, grad_input, grad_output):
    for i in grad_input:
        print(i.shape)

model = ThreeLayerNet()
model.fc2.register_backward_hook(backward_hook)
input_data = torch.randn(1, 10)
target = torch.randn(1, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
output = model(input_data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

I was trying to print out the input gradients for the fc2 layer, as shown in the hook function above, and I got a tuple of shapes:

(torch.Size([50]), torch.Size([1, 20]), torch.Size([20, 50]))

When printing the output gradients I got torch.Size([1, 50]).

Intuitively, fc1 -> fc2 -> fc3 is the forward pass, and grad_input contains the gradients w.r.t. the input of the layer (so grads w.r.t. fc1). Similarly, grad_output contains the gradients w.r.t. the output of the layer (so grads w.r.t. fc3). But I'm really confused about why I got this output. :confused:

Your code raises a warning which you should not ignore:

UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.

Using .register_full_backward_hook instead gives a grad_input of shape [1, 20].
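
The three shapes you saw come from the deprecated hook firing on the last autograd node of the layer rather than on the module itself: for an nn.Linear that node is an addmm, so the tuple holds the gradients w.r.t. (bias, input, weight.t()), i.e. [50], [1, 20] and [20, 50]. The full hook reports only the gradient w.r.t. the module's actual input. Here is a minimal sketch (reusing the ThreeLayerNet defined above; the hook name full_backward_hook is just an example):

import torch
import torch.nn as nn

def full_backward_hook(module, grad_input, grad_output):
    # grad_input: gradients w.r.t. the inputs of fc2 -> one tensor of shape [1, 20]
    # grad_output: gradients w.r.t. the output of fc2 -> one tensor of shape [1, 50]
    for g in grad_input:
        print("grad_input:", g.shape)
    for g in grad_output:
        print("grad_output:", g.shape)

model = ThreeLayerNet()
model.fc2.register_full_backward_hook(full_backward_hook)

x = torch.randn(1, 10)
target = torch.randn(1, 1)
loss = nn.MSELoss()(model(x), target)
loss.backward()
# prints grad_input: torch.Size([1, 20]) and grad_output: torch.Size([1, 50])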

Ahhhhhh, my bad, it's working now :slight_smile: