Forward hook does not work with __call__ method

A hook registered with register_forward_hook on a subclass of nn.Module does not fire when the subclass defines __call__ instead of forward.

This code works:

import torch
import torch.nn as nn

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

# forward hooks are called as hook(module, input, output)
def print_forward(module, x_in, x_out):
    print('forward hook')

eye = Identity()
eye.register_forward_hook(print_forward)
x = eye(torch.rand(10)) # prints "forward hook"

This one doesn’t (the only change is renaming forward to __call__):

# this doesn't
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def __call__(self, x):
        return x

def print_forward(module, x_in, x_out):
    print('forward hook')

eye = Identity()
eye.register_forward_hook(print_forward)
x = eye(torch.rand(10)) # no print here

Is this a design decision? It would be nice to have a warning message in that case.

It is a design decision: nn.Module.__call__ is the method that runs the registered hooks and calls forward, as seen in the nn.Module source.
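To make that mechanism concrete, here is a rough pure-Python sketch. MiniModule, ViaForward, ViaCall and record are hypothetical names invented for illustration; this is a heavily simplified stand-in and not the actual nn.Module implementation, which does considerably more.

```python
class MiniModule:
    """Hypothetical, simplified stand-in for nn.Module (illustration only)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, *args):
        raise NotImplementedError

    def __call__(self, *args):
        # The base-class __call__ runs forward, then every registered hook.
        output = self.forward(*args)
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output


class ViaForward(MiniModule):
    def forward(self, x):
        return x


class ViaCall(MiniModule):
    def __call__(self, x):  # replaces the dispatch logic in MiniModule.__call__
        return x


calls = []

def record(module, inputs, output):
    calls.append(type(module).__name__)

for cls in (ViaForward, ViaCall):
    m = cls()
    m.register_forward_hook(record)
    m(1)

print(calls)  # ['ViaForward']
```

Only the subclass that implements forward reaches the hook loop; the one that overrides __call__ skips it entirely.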

Since you’ve defined your own __call__ method, nn.Module.__call__ is never invoked, so the hooks are never run.
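If a custom __call__ is really needed, one possible workaround (a sketch, not an officially recommended pattern) is to keep the logic in forward and have the override delegate to super().__call__, so the hook dispatch still runs. The fired list and record_forward hook below are hypothetical helpers added for the example.

```python
import torch
import torch.nn as nn


class Identity(nn.Module):
    def forward(self, x):
        return x

    def __call__(self, x):
        # Any custom pre-processing would go here (hypothetical).
        # Delegating to nn.Module.__call__ keeps hook dispatch intact.
        return super().__call__(x)


fired = []

def record_forward(module, inputs, output):
    # Forward hooks receive (module, input, output).
    fired.append(type(module).__name__)

eye = Identity()
eye.register_forward_hook(record_forward)
out = eye(torch.rand(10))
print(fired)  # ['Identity']
```

In most cases simply implementing forward and leaving __call__ alone is the simpler choice.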