Hook on a tensor in forward pass of a module?

Given the following code:

class Example(nn.Module):
  def __init__(self):
    super().__init__()
    self.l1 = nn.Linear(1,3)
    self.l2 = nn.Linear(3,2)
  
  def forward(self, x):
    mid = self.l1(x)
    res = self.l2(mid)
    return res

is there a way to attach a hook to the mid variable in the forward call?
something like:

example.forward.mid.register_hook(lambda x: print(x))

Hi,

Why not add mid.register_hook(lambda x: print(x)) to your forward function?
You can also return mid from the function so that the caller can add the hook.
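To make that suggestion concrete, a minimal sketch (not from the original post): a Tensor.register_hook added inside forward fires during the backward pass and receives the gradient of mid, not its value.

```python
import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 3)
        self.l2 = nn.Linear(3, 2)

    def forward(self, x):
        mid = self.l1(x)
        # Tensor hooks run during backward and get the gradient w.r.t. mid
        mid.register_hook(lambda grad: print("grad of mid:", grad))
        return self.l2(mid)

model = Example()
out = model(torch.randn(1))
out.sum().backward()  # triggers the hook
```

Note that this only fires when backward() is called; to see the forward value of mid you would print mid directly or use a forward hook.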

I was going through someone else's code; this is what I ended up doing.

How about example.l1.register_hook(lambda x: print(x)), since example.l1 is also an nn.Module?

register_hook does not exist on nn.Module. And the closest thing that does, register_backward_hook(), is not working properly at the moment and should not be used.
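For completeness, newer PyTorch releases (1.8+) provide register_full_backward_hook as the fixed replacement for register_backward_hook. A minimal sketch, assuming a recent PyTorch version:

```python
import torch
import torch.nn as nn

lin = nn.Linear(3, 2)

def bwd_hook(module, grad_input, grad_output):
    # grad_output is a tuple holding the gradient w.r.t. the module's output
    print("grad_output shape:", grad_output[0].shape)

# Available in PyTorch >= 1.8; replaces the buggy register_backward_hook
lin.register_full_backward_hook(bwd_hook)
lin(torch.randn(4, 3)).sum().backward()
```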

register_forward_hook expects a function with the signature (module, input, output), not a single-argument lambda like lambda x: print(x). Try this example.

import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1,3)
        self.l2 = nn.Linear(3,2)
    
    def forward(self, x):
        mid = self.l1(x)
        res = self.l2(mid)
        return res


if __name__ == "__main__":
    model = Example()
    dummy_input = torch.randn([1])
    print(dummy_input)

    # Forward hooks receive (module, input, output); output here is mid
    def hook(module, input, output):
        print(output)

    model.l1.register_forward_hook(hook)
    out = model(dummy_input)  # running the forward pass triggers the hook
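A common variation on the example above (a sketch, not from the thread) stores the activation instead of printing it, using a closure, so mid can be inspected after the forward pass:

```python
import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 3)
        self.l2 = nn.Linear(3, 2)

    def forward(self, x):
        return self.l2(self.l1(x))

model = Example()
activations = {}

def save_hook(name):
    # Closure captures `name` so one factory serves many layers
    def hook(module, input, output):
        activations[name] = output.detach()
    return hook

handle = model.l1.register_forward_hook(save_hook("mid"))
model(torch.randn(1))
print(activations["mid"])  # the output of l1, i.e. mid
handle.remove()  # detach the hook when no longer needed
```

register_forward_hook returns a handle whose remove() method unregisters the hook, which is useful when you only want to capture activations temporarily.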