I want to call a function on every operation in the forward pass of my neural network.
So, naturally, I am registering forward hooks.
However, I am currently doing something like this:
    for name, module in net.named_modules():
        module.register_forward_hook(hook_fn)
This only registers hooks on modules (nn.Module instances such as Conv2d), not on simple operators that are not modules, e.g., + or torch.cat().
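To make the problem concrete, here is a minimal sketch that reproduces what I am seeing. TinyNet, the layer sizes, and the body of hook_fn are made up for illustration and are not my real network; the hook just records the class name of whatever module triggered it.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model used only to reproduce the behaviour."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x):
        a = self.conv1(x)
        b = self.conv2(x)
        s = a + b                        # plain tensor op, no module involved
        return torch.cat([s, a], dim=1)  # functional call, no module involved


seen = []  # class names of the modules whose hook fired, in order


def hook_fn(module, inputs, output):
    # Illustrative hook body: just log which module was called.
    seen.append(type(module).__name__)


net = TinyNet()
for name, module in net.named_modules():
    module.register_forward_hook(hook_fn)

out = net(torch.randn(1, 3, 8, 8))
print(seen)  # ['Conv2d', 'Conv2d', 'TinyNet']
```

The hook fires for the two Conv2d layers and for the top-level TinyNet (named_modules() also yields the root module), but nothing is recorded for the + or the torch.cat() call.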
I want this hook to be called on every operator node, including simple ones like + or torch.cat().
Is there a way to do so?
Thank you.