I recently learned about `register_backward_hook` and `register_forward_hook` for `nn.Module`.
I have some questions about `register_backward_hook`. This is the sample code I am using:
class Model(nn.Module):
    """Small MLP: Linear(3, 4) -> ReLU -> Linear(4, 1)."""

    def __init__(self):
        super().__init__()
        # Submodules are registered in this order, so named_children()
        # yields fc1, relu1, fc2 in that sequence.
        self.fc1 = nn.Linear(3, 4)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4, 1)

    def forward(self, x):
        """Run ``x`` through fc1, relu1 and fc2 and return fc2's output."""
        hidden = self.relu1(self.fc1(x))
        return self.fc2(hidden)
# Captured hook data, keyed by the module object the hook fired on.
forward_values = {}
backward_values = {}


def hook_fn_forward(module, inp, out):
    """Forward hook: store what ``module`` received and produced.

    Replaces any previous entry for ``module`` with a fresh
    ``{"input": inp, "output": out}`` dict.
    """
    forward_values[module] = {"input": inp, "output": out}


def hook_fn_backward(module, inp_grad, out_grad):
    """Backward hook: store the gradients autograd passed for ``module``.

    ``inp_grad`` / ``out_grad`` are stored exactly as received (the hook's
    grad_input / grad_output arguments). Replaces any previous entry for
    ``module``.
    """
    backward_values[module] = {"input": inp_grad, "output": out_grad}
model = Model()

# The legacy Module.register_backward_hook is deprecated (PyTorch >= 1.8).
# For modules made of several autograd ops it reports the gradients of the
# module's *last* op rather than of the module's inputs. That is likely why
# fc2's grad_input has 3 entries.
# NOTE(review): presumably the grads of nn.Linear's final addmm
# (bias, input, weight); confirm for your version.
# Prefer the "full" hook when this PyTorch provides it; fall back otherwise
# (e.g. on 1.7.1, which only has the legacy hook).
use_full_backward_hook = hasattr(model, "register_full_backward_hook")

for module in model.children():
    module.register_forward_hook(hook_fn_forward)
    if use_full_backward_hook:
        module.register_full_backward_hook(hook_fn_backward)
    else:
        module.register_backward_hook(hook_fn_backward)

# batch size of 1 -> shape = (1, 3)
x = torch.tensor([[1.0, 1.0, 1.0]])
o = model(x)
# o holds a single element, so backward() needs no explicit gradient argument.
o.backward()
I was checking the input and output gradients for the layer `self.fc2` (the last layer). The output looks like this:
The output gradient, 1, makes sense. But why is the input gradient a tuple, and why does it have 3 values? Can anyone explain what these represent?

I am using PyTorch version 1.7.1.