How should register_forward_hook be used?

I need to use the register_forward_hook method to capture the outputs of some layers, but the hooks are never triggered when I call .forward() on a module. However, when I call the module directly instead, e.g. linear1(x), the hooks do receive the results. Is that a bug?

This does not work:

from torch import nn

class cls_model(nn.Module):
    """Three stacked linear layers (3 -> 2 -> 4 -> 5).

    This variant deliberately invokes each sub-layer through its
    ``.forward`` method instead of calling the layer object, which is the
    pattern under discussion: hooks registered on the sub-layers are not
    triggered this way.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(3, 2)
        self.l2 = nn.Linear(2, 4)
        self.l3 = nn.Linear(4, 5)

    def forward(self, x):
        """Push ``x`` through l1, l2 and l3 in order and return the result."""
        # Explicit .forward() calls are kept on purpose -- they skip the
        # layers' hook machinery, which is the behaviour being reproduced.
        hidden = self.l1.forward(x)
        hidden = self.l2.forward(hidden)
        return self.l3.forward(hidden)
    
def forward_hook(module, input, output):
    """Forward hook that reports the layer's class name and tensor shapes.

    ``input`` is indexed at position 0 to obtain the first positional
    argument passed to the layer. Nothing is returned.
    """
    report = (
        f"Inside forward hook for {module.__class__.__name__}",
        f"Input shape: {input[0].shape}",
        f"Output shape: {output.shape}",
        "--------",
    )
    print(*report, sep="\n")
# The snippet only imports `nn`; `torch` itself is needed for torch.randn
# below, otherwise this script fails with NameError before any hook logic runs.
import torch

model = cls_model()
# Attach the hook to the middle layer only; the returned handle can be used
# to detach it again.
hook_handle = model.l2.register_forward_hook(forward_hook)
x = torch.randn((6, 3))  # 6 rows with 3 features each, matching l1's input size
# Prints nothing: cls_model.forward invokes self.l2.forward(...) directly,
# so the hook registered on l2 is never triggered.
model.forward(x)

This works:

from torch import nn

class cls_model(nn.Module):
    """Three stacked linear layers (3 -> 2 -> 4 -> 5).

    Each sub-layer is invoked by calling the layer object itself, so hooks
    registered on the sub-layers are triggered during ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(3, 2)
        self.l2 = nn.Linear(2, 4)
        self.l3 = nn.Linear(4, 5)

    def forward(self, x):
        """Push ``x`` through l1, l2 and l3 in order and return the result."""
        activation = x
        # Calling the layer (not layer.forward) keeps its hooks active.
        for layer in (self.l1, self.l2, self.l3):
            activation = layer(activation)
        return activation
    
def forward_hook(module, input, output):
    """Print the hooked layer's class name plus its input and output shapes.

    Only the first element of ``input`` is inspected. The hook returns
    nothing.
    """
    layer_name = module.__class__.__name__
    in_shape = input[0].shape
    out_shape = output.shape
    print(
        "\n".join(
            [
                f"Inside forward hook for {layer_name}",
                f"Input shape: {in_shape}",
                f"Output shape: {out_shape}",
                "--------",
            ]
        )
    )
# The snippet only imports `nn`; `torch` itself is needed for torch.randn
# below, otherwise this script fails with NameError.
import torch

model = cls_model()
# Attach the hook to the middle layer only; the returned handle can be used
# to detach it again.
hook_handle = model.l2.register_forward_hook(forward_hook)
x = torch.randn((6, 3))  # 6 rows with 3 features each, matching l1's input size
# Call the module itself rather than model.forward(x): going through
# __call__ ensures hooks registered on `model` would run as well.
model(x)
1 Like

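No, this is not a bug; it is exactly why you should call the module directly instead of calling its .forward method. Calling the module goes through nn.Module.__call__, which runs the registered hooks around forward; calling .forward yourself bypasses that, so e.g. hooks will be skipped.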

2 Likes