register_forward_hook is not supported on ScriptModules

Hi torch experts,

I am trying to add a forward hook to my model, but I get an error message saying that hooks don't work with JIT modules.

Is there a particular reason why hooks don't work with JIT modules? How can I work around this while still keeping my module as a script module?

Thanks in advance!


`register_forward_hook` (and `register_backward_hook`) currently aren't supported in TorchScript. If you'd like to see them added, please file a feature request on GitHub.
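If you need the module to stay scriptable in the meantime, one possible workaround is to do the hook's work inside `forward` itself, so the compiler sees it. This is a minimal sketch, not an official replacement for hooks: it assumes you can edit the model's `forward`, and that the hook only needs to log something and expose an intermediate activation. The class and variable names are illustrative.

```python
import torch
import torch.nn as nn


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, 2)

    def forward(self, x):
        conv_out = self.conv(x)
        # Work that a forward hook on `self.conv` would have done is done
        # here instead, so TorchScript compiles it along with the model.
        print("Hello from inside forward")
        # Return the intermediate activation alongside the result rather
        # than capturing it from a hook.
        return conv_out + 10, conv_out


scripted = torch.jit.script(M())
out, activation = scripted(torch.randn(5, 5, 2, 2))
print(activation.shape)  # torch.Size([5, 5, 1, 1])
```

The trade-off is that the "hook" is now part of the model's code and its return signature, so it can't be attached or removed at runtime the way `register_forward_hook` allows.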

In your particular case, it sounds like you're inheriting from `ScriptModule` as a way to access the TorchScript compiler. An API change in v1.2.0 lets you compile `nn.Module`s with `torch.jit.script` without making them inherit from `ScriptModule` directly; see these docs for details. Note that forward hooks are still dropped during compilation, as this example shows:

```python
import torch
import torch.nn as nn


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, 2)

    def forward(self, x):
        return self.conv(x) + 10


# Forward hooks are called as hook(module, input, output)
def my_hook(module, *args):
    print("Hello from my_hook")


m = M()
m.conv.register_forward_hook(my_hook)

# `my_hook` will be called
m(torch.randn(5, 5, 2, 2))

a_scripted_module = torch.jit.script(m)

# `my_hook` will NOT be called, forward hooks are lost
# when an `nn.Module` is compiled
a_scripted_module(torch.randn(5, 5, 2, 2))
```
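For comparison, here is a sketch of the two ways to get a `ScriptModule`: the pre-1.2 style, which inherits from `torch.jit.ScriptModule` and marks `forward` with `@torch.jit.script_method`, and the v1.2+ style, which calls `torch.jit.script` on a plain `nn.Module`. The class names `OldStyle` and `NewStyle` are illustrative.

```python
import torch
import torch.nn as nn


# Pre-1.2 style: inherit from ScriptModule and mark methods to compile
class OldStyle(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, 2)

    @torch.jit.script_method
    def forward(self, x):
        return self.conv(x) + 10


# v1.2+ style: write a plain nn.Module and compile it with torch.jit.script
class NewStyle(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, 2)

    def forward(self, x):
        return self.conv(x) + 10


old = OldStyle()
new = torch.jit.script(NewStyle())

# Both objects are ScriptModules
print(isinstance(old, torch.jit.ScriptModule))  # True
print(isinstance(new, torch.jit.ScriptModule))  # True
```

With the second style, the class itself stays an ordinary `nn.Module`, so you can still run (and hook) it eagerly while debugging and only compile it when you need TorchScript.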