Can I register forward hooks for functional calls?

As we know, we can register forward hooks on nn.Module instances. But there are also functional operations such as nn.functional.interpolate() or torch.cat(), and even implicit callables such as <built-in function add> for the element-wise sum of tensor_a and tensor_b, as in:

tensor_c = tensor_a + tensor_b  # <built-in function add> is called here

Can we use hooks for these functional calls?

I don’t think that’s possible, since hooks are invoked by nn.Module.__call__, which wraps the call to forward (forward pre-hooks run before forward, forward hooks run right after it). A plain function call never goes through that method.
Since you are using the functional API, you already have the input and output activations and could, e.g., store them directly if needed. If hooks are more convenient than handling the tensors directly, you would probably have to wrap these calls in a custom module.
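A minimal sketch of the wrapping approach, assuming a recent PyTorch. The Interpolate class and the save_activation hook are hypothetical names chosen for illustration; only F.interpolate and register_forward_hook are real PyTorch APIs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Interpolate(nn.Module):
    """Hypothetical thin wrapper around F.interpolate so hooks can be attached."""

    def __init__(self, scale_factor=2.0, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)


captured = {}


def save_activation(module, inputs, output):
    # A forward hook receives the module, the tuple of positional inputs,
    # and the output returned by forward.
    captured["input_shape"] = tuple(inputs[0].shape)
    captured["output_shape"] = tuple(output.shape)


up = Interpolate(scale_factor=2.0)
handle = up.register_forward_hook(save_activation)
y = up(torch.randn(1, 3, 8, 8))
handle.remove()  # detach the hook once it is no longer needed

print(captured)  # {'input_shape': (1, 3, 8, 8), 'output_shape': (1, 3, 16, 16)}
```

The catch, as noted above, is that the model's forward must call the wrapper instead of the bare function, which requires editing the model code.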

I cannot alter the original module definitions because I am implementing a tool that has to accept any popular backbone, such as ResNet, GoogLeNet, DenseNet, etc.
Luckily, I managed to infer the input and output shapes of the functional calls indirectly by recursively searching the torch.fx.graph.Graph. Thank you anyway.
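The exact graph-search code isn't shown here, so the following is a different, hedged sketch of the same torch.fx idea: trace the model with torch.fx.symbolic_trace, then run it through a torch.fx.Interpreter subclass whose call_function override acts like a forward hook for every functional call (F.interpolate, torch.cat, and the <built-in function add> that `+` becomes). TinyNet and FunctionalShapeRecorder are hypothetical names; torch.fx.passes.shape_prop.ShapeProp is a built-in Interpreter subclass that stores similar shape metadata on each node. It assumes the model is symbolically traceable (no data-dependent control flow).

```python
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    # Hypothetical toy backbone mixing a module with functional calls.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x):
        a = self.conv(x)
        b = F.interpolate(a, scale_factor=2.0)
        c = torch.cat([b, b], dim=1)
        return c + c  # traced as a call_function node with target operator.add


def _tensors(obj):
    # Collect tensors from (possibly nested) lists/tuples of arguments,
    # e.g. the list passed to torch.cat.
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return [t for item in obj for t in _tensors(item)]
    return []


class FunctionalShapeRecorder(fx.Interpreter):
    """Executes a traced graph and records shapes around call_function nodes."""

    def __init__(self, gm):
        super().__init__(gm)
        self.records = []

    def call_function(self, target, args, kwargs):
        out = super().call_function(target, args, kwargs)
        in_shapes = [tuple(t.shape) for t in _tensors(list(args) + list(kwargs.values()))]
        out_shape = tuple(out.shape) if isinstance(out, torch.Tensor) else None
        name = getattr(target, "__name__", str(target))
        self.records.append((name, in_shapes, out_shape))
        return out


model = TinyNet()
gm = fx.symbolic_trace(model)  # module definition itself is left untouched
x = torch.randn(1, 3, 8, 8)

recorder = FunctionalShapeRecorder(gm)
out = recorder.run(x)
for name, in_shapes, out_shape in recorder.records:
    print(f"{name}: {in_shapes} -> {out_shape}")
```

Because the Interpreter executes real tensors, this also gives access to the activations themselves, not just their shapes; calls to submodules (here self.conv) appear as call_module nodes and can still use ordinary nn.Module hooks.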
