Register a hook for all operators, including simple ones such as add

I want to call a function on every operator in the forward pass of my NN.
Registering a forward hook seemed like the natural way to do this.
Currently I am doing something like this:

for name, module in net.named_modules():
    module.register_forward_hook(hook_fn)

but this only registers hooks on modules, e.g., Conv2d, and not on simple operators, e.g., + or torch.cat().
I want the hook to be called on every operator node, including those simple ones.
Is there a way to do so?

Thank you.

Hi,

I am afraid these functions sit at a much lower level than torch.nn, so you won’t be able to hook them directly with the module hooking system.
One thing you can try is to wrap your input in a Tensor-like object and write a custom implementation that calls your hook and then delegates to the original PyTorch implementation. See: Extending PyTorch — PyTorch 1.8.1 documentation
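A minimal sketch of that approach, assuming the `__torch_function__` protocol described on the Extending PyTorch page and a PyTorch version that provides `Tensor.as_subclass`. The names `HookedTensor`, `hook_fn`, `calls` and `Net` are made up for illustration and are not part of PyTorch:

```python
import torch
import torch.nn as nn

calls = []  # names of the intercepted operations, in call order


def hook_fn(func, args, kwargs):
    # Hypothetical hook: only record the operator name.
    # Avoid printing the tensors here, since repr() of a subclass
    # instance goes through __torch_function__ again.
    calls.append(getattr(func, "__name__", repr(func)))


class HookedTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        hook_fn(func, args, kwargs)
        # Delegate to the default implementation, which runs the real op
        # and returns the result as the subclass again.
        return super().__torch_function__(func, types, args, kwargs)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 3)

    def forward(self, x):
        y = self.fc(x) + x               # module call followed by a plain +
        return torch.cat([y, x], dim=0)  # functional call


# Wrap the input so that every op touching it is routed through the hook.
x = torch.randn(2, 3).as_subclass(HookedTensor)
out = Net()(x)
print(calls)
```

Because the default implementation returns the subclass, the hook should keep firing for later operations on the result. It only sees operations in which at least one argument is a `HookedTensor`, so anything computed purely from plain tensors is not intercepted.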
