Question about behavior of register_forward_hook in loops

Hi there.

I tried to implement Grad-CAM with register_forward_hook, but ran into a problem when running inference in a loop over test data from a DataLoader.

I register the following hook with register_forward_hook:

def forward_hook(module, inputs, outputs):
    global feature
    feature = outputs[0].clone()

forward_handle = target_layer.register_forward_hook(forward_hook)

I expected “feature” to be updated by the forward pass in each loop iteration, but it was not: “feature” stayed fixed at the value computed from the first sample.

As a workaround, I moved the register_forward_hook call (the forward_handle assignment) into the loop, which gave the behavior I wanted, but I don’t think this is a good way to do it.

Is this the intended behavior of register_forward_hook? If I am doing something wrong, I would appreciate your advice.

Thank you.

Yes, that’s the case. The registered forward hook will be triggered in each forward pass as seen here:

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer = nn.Identity()
        
    def forward(self, x):
        x = self.layer(x)
        return x


activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook


model = MyModel()
model.layer.register_forward_hook(get_activation('identity'))

x = torch.zeros(1, 10)
output = model(x)
print(activation['identity'])
# tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

x = torch.ones(1, 10)
output = model(x)
print(activation['identity'])
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

I’m unsure what might be causing your issue, so please feel free to post a minimal, executable code snippet reproducing your issue.
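
For reference, here is a minimal sketch of the same check inside a DataLoader loop, with the hook registered once before the loop. The model, the data, and the names `features` and `captured` are made up for illustration and are not taken from your code:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in model; target_layer plays the role of the hooked layer.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
target_layer = model[0]

features = {}

def forward_hook(module, inputs, output):
    # Overwritten on every forward pass through target_layer.
    features["latest"] = output.detach().clone()

# Register once, outside the loop.
handle = target_layer.register_forward_hook(forward_hook)

# Three samples with clearly different values, one per batch.
data = torch.arange(12, dtype=torch.float32).reshape(3, 4)
loader = DataLoader(TensorDataset(data), batch_size=1)

captured = []
with torch.no_grad():
    for (x,) in loader:
        model(x)
        captured.append(features["latest"])

# Remove the hook once it is no longer needed.
handle.remove()
```

Each entry in `captured` should correspond to its own batch. If you do register inside the loop instead, calling `handle.remove()` at the end of each iteration keeps hooks from accumulating on the layer.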

Thank you for the reply and the clarification.
I understand more about PyTorch now.

I have solved the issue myself.
The cause was a trivial mistake: the hook stored the intermediate layer’s outputs in a list, and I was always retrieving them with a fixed index.
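
A minimal sketch of how that kind of mistake can produce the symptom described above. The original code was not posted, so the names `outputs_list` and `hook` and the use of `nn.Identity` are assumptions for illustration only:

```python
import torch
import torch.nn as nn

layer = nn.Identity()
outputs_list = []  # hypothetical storage filled by the hook

def hook(module, inputs, output):
    # The hook fires on every forward pass and appends a new entry.
    outputs_list.append(output.detach())

handle = layer.register_forward_hook(hook)

stale, fresh = [], []
for value in (0.0, 1.0, 2.0):
    layer(torch.full((1, 2), value))
    stale.append(outputs_list[0])   # fixed index: always the first batch
    fresh.append(outputs_list[-1])  # last element: the current batch

handle.remove()

print([t[0, 0].item() for t in stale])  # [0.0, 0.0, 0.0]
print([t[0, 0].item() for t in fresh])  # [0.0, 1.0, 2.0]
```

The hook itself runs on every iteration; only the read with a fixed index makes the result look frozen at the first sample.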