I need to use register_forward_hook
method to receive some layers output but I will not receive the results when I call .forward()
for any module however, when I change it to just calling that module like linear1()
, I get the results on hooks. Is that a bug?
This does not work:
from torch import nn
class cls_model(nn.Module):
def __init__(self):
super(cls_model,self).__init__()
self.l1 = nn.Linear(3,2)
self.l2 = nn.Linear(2,4)
self.l3 = nn.Linear(4,5)
def forward(self, x):
y1 = self.l1.forward(x)
y2 = self.l2.forward(y1)
y3 = self.l3.forward(y2)
return y3
def forward_hook(module, input, output):
print(f"Inside forward hook for {module.__class__.__name__}")
print(f"Input shape: {input[0].shape}")
print(f"Output shape: {output.shape}")
print("--------")
model = cls_model()
hook_handle = model.l2.register_forward_hook(forward_hook)
x = torch.randn((6,3))
model.forward(x)
This works
from torch import nn
class cls_model(nn.Module):
def __init__(self):
super(cls_model,self).__init__()
self.l1 = nn.Linear(3,2)
self.l2 = nn.Linear(2,4)
self.l3 = nn.Linear(4,5)
def forward(self, x):
y1 = self.l1(x)
y2 = self.l2(y1)
y3 = self.l3(y2)
return y3
def forward_hook(module, input, output):
print(f"Inside forward hook for {module.__class__.__name__}")
print(f"Input shape: {input[0].shape}")
print(f"Output shape: {output.shape}")
print("--------")
model = cls_model()
hook_handle = model.l2.register_forward_hook(forward_hook)
x = torch.randn((6,3))
model.forward(x)