Hi torch experts,
I am trying to add a forward hook to my model, but I get an error message saying that hooks won’t work with a jit module.
Is there a particular reason why hooks don’t work with jit modules? How can I work around this while still keeping my module as a script module?
Thanks in advance!
Module hooks (forward hooks as well as `register_backward_hooks`) currently aren’t supported in TorchScript. If you’d like to see them added, please file a feature request on GitHub.
In your particular case it sounds like you’re inheriting from `ScriptModule` as the way to access the TorchScript compiler. An API change in v1.2.0 lets you compile `nn.Module`s without making them inherit from `ScriptModule` directly; see these docs for details:
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, 2)

    def forward(self, x):
        return self.conv(x) + 10

def my_hook(self, *args):
    print("Hello from my_hook")

m = M()
m.register_forward_hook(my_hook)

# `my_hook` will be called
m(torch.randn(5, 5, 2, 2))

a_scripted_module = torch.jit.script(m)

# `my_hook` will NOT be called, forward hooks are lost
# when an `nn.Module` is compiled
a_scripted_module(torch.randn(5, 5, 2, 2))
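One possible workaround, which is not from the answer above and is only a sketch: keep the compiled module as a submodule of a thin, uncompiled `nn.Module` and register the hook on that outer module. The names `HookableWrapper` and `calls` are hypothetical, and the wrapper itself is an ordinary Python module, so the hook only runs when you call the model from Python.

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(5, 5, 2)

    def forward(self, x):
        return self.conv(x) + 10

class HookableWrapper(nn.Module):
    """Hypothetical helper: a plain nn.Module that delegates to a scripted one."""
    def __init__(self, scripted):
        super().__init__()
        self.scripted = scripted

    def forward(self, x):
        return self.scripted(x)

calls = []  # records the output shape each time the hook fires

def my_hook(module, inputs, output):
    calls.append(tuple(output.shape))

# The inner module is compiled; the outer wrapper is not.
wrapper = HookableWrapper(torch.jit.script(M()))
handle = wrapper.register_forward_hook(my_hook)

out = wrapper(torch.randn(5, 5, 2, 2))
handle.remove()  # detach the hook once it is no longer needed
```

The hook sees only the wrapper's inputs and outputs, not intermediate values inside the scripted module, and it disappears again if the wrapper itself is compiled or exported.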