`register_forward_hook` doesn't work for NestedTensor

Hi! I observed some interesting behavior of register_forward_hook but am not sure whether it is expected, so I am posting it here in the hope that someone can clarify it for me.

First, I initialized a model and loaded a checkpoint from a previous run for inference:

best_model = my_model(args)
checkpoint = torch.load(args.model_save_dir+'/best_model.pt')
best_model.load_state_dict(checkpoint)
best_model = best_model.cpu()

Then, I tried to make predictions by passing the model to another method:

get_all_preds(best_model, test_batch, other_args)

The get_all_preds function contains a get_activation helper and a dictionary meant to receive the attention weights from the last attention layer:

activation = {}
def get_activation(name, activation):
    def hook(model, input, output):
        activation[name] = output
    return hook

best_model.transformer.transformer_encoder.layers[-1].self_attn.register_forward_hook(get_activation('last-layer-attention', activation))

best_model.eval()
preds = best_model(test_batch)

<Breakpoint>

I thought I would get an activation with some attention weights, but the activation is empty at the breakpoint.

HOWEVER, if I simply move the code snippet from inside get_all_preds into the main method, right after the place where best_model is instantiated, I can get the attention weights without a problem!

I.e.:

best_model = my_model(args)
checkpoint = torch.load(args.model_save_dir+'/best_model.pt')
best_model.load_state_dict(checkpoint)
best_model = best_model.cpu()

activation = {}
def get_activation(name, activation):
    def hook(model, input, output):
        activation[name] = output
    return hook

best_model.transformer.transformer_encoder.layers[-1].self_attn.register_forward_hook(get_activation('last-layer-attention', activation))

best_model.eval()
preds = best_model(test_batch)

<Breakpoint>

My question is: why does passing the model to another method break the hook pipeline?

The reason why PyTorch has you implement forward instead of overriding Python’s default __call__ is that all this machinery is coded in __call__.

When you call model(...) you are calling model.__call__(...), which internally calls the forward function.
The hooks are triggered in __call__.
If you call forward directly you are skipping this.
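
A minimal sketch of the difference, using a toy nn.Linear rather than anything from your model:

import torch
import torch.nn as nn

activation = {}

def get_activation(name, activation):
    def hook(module, input, output):
        activation[name] = output
    return hook

net = nn.Linear(4, 2)  # toy module, just for illustration
net.register_forward_hook(get_activation('linear', activation))

x = torch.randn(1, 4)

net.forward(x)                               # bypasses __call__: the hook does NOT run
print('after forward():', list(activation))  # []

net(x)                                       # goes through __call__: the hook runs
print('after net(x):', list(activation))     # ['linear']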

Thanks for your response! I think I am calling model(...) in the same way, though, so why isn’t __call__ triggered when model(...) is called inside another method while the model is declared in the main method?

It doesn’t make sense :slight_smile:
Could you make a reproducible script?
I would say you have a variable-scope problem or some other silly issue.

After a controlled experiment, I think the problem is either in my model or in model.eval() + torch.no_grad().

When @torch.no_grad() is present: if there is no model.eval() inside the get_test method, the activation is recorded fine; when model.eval() is there, no activation is recorded. Of course, the outputs of these two cases differ a lot because of the dropout and normalization layers in the model.

When @torch.no_grad() is not present: the activation is always recorded.

This still doesn’t make sense to me, though.

The following snippet shows a controlled experiment:

@torch.no_grad()
def get_test(model, test_batch):
    # model.eval()  # Comment and uncomment this
    return model(test_batch)

def get_activation(name, activation):
    def hook(model, input, output):
        activation[name] = output
    return hook

activation = {}
model = MyModel(args)
checkpoint = torch.load(some_path)
model.load_state_dict(checkpoint)
model = model.to(device)

model.transformer.transformer_encoder.layers[-1].self_attn.register_forward_hook(get_activation('last-layer-attention', activation))

print(get_test(model, all_batch))
print(activation)

To rule out a PyTorch issue, as opposed to something in the code of your model, you could try to catch an activation from a toy model.
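
For example, something along these lines (a placeholder toy model, not your code), keeping the same eval + no_grad setup:

import torch
import torch.nn as nn

activation = {}

def get_activation(name, activation):
    def hook(module, input, output):
        activation[name] = output
    return hook

# Toy model standing in for MyModel, just to test the hook machinery itself.
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
toy[-1].register_forward_hook(get_activation('last-linear', activation))

@torch.no_grad()
def get_test(model, batch):
    model.eval()
    return model(batch)

get_test(toy, torch.randn(4, 8))
print(activation.keys())  # 'last-linear' should appear even under eval + no_grad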

__call__ implementation

    def _call_impl(self, *input, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
            return forward_call(*input, **kwargs)
        # Do not call functions when jit is used
        full_backward_hooks, non_full_backward_hooks = [], []
        if self._backward_hooks or _global_backward_hooks:
            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
        if _global_forward_pre_hooks or self._forward_pre_hooks:
            for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
                result = hook(self, input)
                if result is not None:
                    if not isinstance(result, tuple):
                        result = (result,)
                    input = result

        bw_hook = None
        if full_backward_hooks:
            bw_hook = hooks.BackwardHook(self, full_backward_hooks)
            input = bw_hook.setup_input_hook(input)

        result = forward_call(*input, **kwargs)
        if _global_forward_hooks or self._forward_hooks:
            for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
                hook_result = hook(self, input, result)
                if hook_result is not None:
                    result = hook_result

        if bw_hook:
            result = bw_hook.setup_output_hook(result)

        # Handle the non-full backward hooks
        if non_full_backward_hooks:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in non_full_backward_hooks:
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
                self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

        return result

Apparently there is no condition that skips the hooking system under no_grad or in eval mode. I’d say it should work in both cases.

I agree. It also has nothing to do with loading the checkpoint / lazy initialization. Now the only suspect seems to be the model itself… That’s a headache.

Ahhhhhh, I got it! It’s because of the fast path implemented by torch.nn.TransformerEncoderLayer (see here). It looks like the hook just doesn’t apply on the NestedTensor fast path. When I break any of the fast-path premises (forcing the normal path), the hook works as usual!

Does anyone happen to know how to hook torch.NestedTensor? It took me almost 6 hours to figure out that it was this fast-path issue…
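
To illustrate what I mean by breaking a premise, here is a minimal sketch on a bare nn.TransformerEncoder (not my actual model; the exact fast-path conditions depend on the PyTorch version, so the first print may not be empty on older releases):

import torch
import torch.nn as nn

# Bare encoder standing in for my model; batch_first=True matches one fast-path premise.
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

activation = {}
def get_activation(name, activation):
    def hook(module, input, output):
        activation[name] = output
    return hook

encoder.layers[-1].self_attn.register_forward_hook(
    get_activation('last-layer-attention', activation))

encoder.eval()
src = torch.randn(2, 5, 16)

with torch.no_grad():   # eval + no_grad: on my version the fast path kicks in,
    encoder(src)        # self_attn is bypassed and the hook does not fire
print('no_grad:', list(activation))

encoder(src)            # grad enabled: one premise broken, normal path,
print('grad on:', list(activation))  # the hook fires and records the output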

Oh it was that :expressionless:
Btw, it would be nice if you opened a PR about it to add a warning or document it so that people are aware.