Type mismatch with hooks input when using JIT

Hi,

I’m trying to register hooks in order to get the layers’ activation values in my model.
It works with the normal (eager) Python runtime (like in this example).

However, I cannot make it work with the JIT compiler.
As asked here, the “input” argument of a hook function is a tuple, and the JIT compiler does not like it:

 Traceback (most recent call last):
  File "main.py", line 22, in <module>
    script_net = torch.jit.script(net)
  File "/home/zodiac/.venv/ia/lib64/python3.7/site-packages/torch/jit/_script.py", line 1258, in script
    obj, torch.jit._recursive.infer_methods_to_compile
  File "/home/zodiac/.venv/ia/lib64/python3.7/site-packages/torch/jit/_recursive.py", line 451, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/zodiac/.venv/ia/lib64/python3.7/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/zodiac/.venv/ia/lib64/python3.7/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/home/zodiac/.venv/ia/lib64/python3.7/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/zodiac/.venv/ia/lib64/python3.7/site-packages/torch/jit/_recursive.py", line 520, in create_script_module_impl
    create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
  File "/home/zodiac/.venv/ia/lib64/python3.7/site-packages/torch/jit/_recursive.py", line 377, in create_hooks_from_stubs
    concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)

RuntimeError: Hook 'hook' on module 'Linear' expected the input argument to be typed as a Tuple but found type: 'Tensor' instead.

This error occured while scripting the forward hook 'hook' on module Linear. If you did not want to script this hook remove it from the original NN module before scripting. 
This hook was expected to have the following signature: hook(self, input: Tuple[Tensor], output: Tensor). 
The type of the output arg is the returned type from either the forward method or the previous hook if it exists. 
Note that hooks can return anything, but if the hook is on a submodule the outer module is expecting the same return type as the submodule's forward.

Here’s the minimum code needed to reproduce this issue:

import torch


class NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 12)

    def forward(self, x):
        return self.l1(x)


def hook(model, input, output):
    pass


net = NN()
net.l1.register_forward_hook(hook)
script_net = torch.jit.script(net)

Any ideas? :)
I’m on Fedora 33, using Python 3.7.12 and PyTorch 1.10.0.

Have a good day!


Same on macOS 12.0.1 using Python 3.9 and PyTorch 1.10 under Rosetta!

Same on Ubuntu 18.04 using Python 3.8 and PyTorch 1.10.

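Hey @Zodiac @caillonantoine, we need type hints to make it work properly. TorchScript assumes that unannotated arguments are Tensors (hence “found type: 'Tensor'” in the error), so the hook’s input argument has to be annotated explicitly as a tuple: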

from typing import Tuple

...

def hook(model, input: Tuple[torch.Tensor], output):
    pass

...
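For reference, here is the original repro with the fix applied, as a minimal sketch assuming PyTorch 1.10 (the version used in this thread). Only the hook signature changes: `input` is annotated as `Tuple[torch.Tensor]` because `Linear.forward` takes a single Tensor, and `output` is annotated as `torch.Tensor` to match what `Linear.forward` returns. Like the snippet above, the hook returns nothing, so the layer’s output passes through unchanged.

```python
from typing import Tuple

import torch


class NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 12)

    def forward(self, x):
        return self.l1(x)


# Forward hooks receive the module's positional inputs packed in a tuple.
# TorchScript would otherwise type the unannotated `input` as a plain Tensor,
# so the tuple type must be spelled out: one Tensor for Linear.forward.
def hook(model, input: Tuple[torch.Tensor], output: torch.Tensor):
    pass


net = NN()
net.l1.register_forward_hook(hook)
script_net = torch.jit.script(net)  # scripting now succeeds

print(script_net(torch.randn(4, 1)).shape)  # torch.Size([4, 12])
```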