`.register_forward_hook` for a subclass of `nn.Module` does not work when `__call__` is used instead of `forward`.
This code works:
```python
import torch
import torch.nn as nn

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

def print_forward(self, x_in, x_out):
    print('forward hook')

eye = Identity()
eye.register_forward_hook(print_forward)
x = eye(torch.rand(10))  # prints "forward hook"
```
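The hook fires here because hooks are not run by `forward` itself: `nn.Module.__call__` calls `forward` and then runs the registered forward hooks. A toy, pure-Python model of that dispatch is sketched below. `MiniModule` is a hypothetical stand-in, not PyTorch's source, and it omits pre-hooks, backward hooks and everything else the real `__call__` handles.

```python
class MiniModule:
    """Toy stand-in for nn.Module: __call__ runs forward() and then the hooks."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args):
        # The hooks live here, not in forward().
        output = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output

    def forward(self, *args):
        raise NotImplementedError


class Identity(MiniModule):
    def forward(self, x):
        return x


class IdentityCall(MiniModule):
    # Overriding __call__ replaces the only code path that runs hooks.
    def __call__(self, x):
        return x


fired = []
for cls in (Identity, IdentityCall):
    m = cls()
    m.register_forward_hook(lambda mod, inp, out: fired.append(type(mod).__name__))
    m(1.0)

print(fired)  # ['Identity']
```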
And this one doesn't (just change `forward` to `__call__`):
```python
# this doesn't
import torch
import torch.nn as nn

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def __call__(self, x):
        return x

def print_forward(self, x_in, x_out):
    print('forward hook')

eye = Identity()
eye.register_forward_hook(print_forward)
x = eye(torch.rand(10))  # no print here
```
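A sketch of two ways to keep the hook working, assuming a recent PyTorch: keep the computation in `forward`, or, if `__call__` really has to be overridden, delegate to `super().__call__` so that `nn.Module`'s hook machinery still runs. `record_hook` and `IdentityWithCall` are illustrative names only.

```python
import torch
import torch.nn as nn

calls = []

def record_hook(module, inputs, output):
    # Forward hooks receive (module, tuple of positional inputs, output).
    calls.append(type(module).__name__)

class Identity(nn.Module):
    # Computation in forward(); nn.Module.__call__ invokes it and runs hooks.
    def forward(self, x):
        return x

class IdentityWithCall(nn.Module):
    def forward(self, x):
        return x

    def __call__(self, *args, **kwargs):
        # Extra logic could go here; delegating keeps hooks working.
        return super().__call__(*args, **kwargs)

for cls in (Identity, IdentityWithCall):
    m = cls()
    m.register_forward_hook(record_hook)
    m(torch.rand(10))

print(calls)  # ['Identity', 'IdentityWithCall']
```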
Is this a design decision? It would be nice to have a warning message in that case.
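As a stopgap, a user-side check is easy to sketch. `warn_if_call_overridden` is a hypothetical helper, not a PyTorch API; it simply compares the class's `__call__` with `nn.Module.__call__`.

```python
import warnings
import torch.nn as nn

def warn_if_call_overridden(module: nn.Module) -> bool:
    """Hypothetical helper: warn if the module's class replaces nn.Module.__call__."""
    overridden = type(module).__call__ is not nn.Module.__call__
    if overridden:
        warnings.warn(
            f"{type(module).__name__} overrides __call__; forward hooks "
            "will not run unless it delegates to nn.Module.__call__."
        )
    return overridden

class Plain(nn.Module):
    def forward(self, x):
        return x

class Overrides(nn.Module):
    def __call__(self, x):
        return x

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    plain_flag = warn_if_call_overridden(Plain())          # False, no warning
    override_flag = warn_if_call_overridden(Overrides())   # True, one warning
```

This is only a heuristic: it also flags subclasses that override `__call__` but correctly delegate to `super().__call__`.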