Hi! I noticed some interesting behavior of register_forward_hook, but I'm not sure whether it is expected, so I'm posting here in the hope that someone can clarify it for me.
First, I initialized a model and loaded a checkpoint from a previous run for inference:
best_model = my_model(args)
checkpoint = torch.load(args.model_save_dir+'/best_model.pt')
best_model.load_state_dict(checkpoint)
best_model = best_model.cpu()
Then, I tried to make predictions by passing the model to another method:
get_all_preds(best_model, test_batch, other_args)
This get_all_preds function contains a get_activation function and a dictionary prepared to receive the attention weights from the last attention layer:
activation = {}

def get_activation(name, activation):
    def hook(model, input, output):
        activation[name] = output
    return hook
best_model.transformer.transformer_encoder.layers[-1].self_attn.register_forward_hook(get_activation('last-layer-attention', activation))
best_model.eval()
preds = best_model(test_batch)
<Breakpoint>
I expected activation to contain some attention weights, but it is empty at the breakpoint.
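For reference, here is a minimal, self-contained sketch of the closure pattern the snippet above relies on, using a toy nn.Linear layer instead of the original model (the layer and key name are made up for illustration). When the hook fires during forward, the shared dictionary is populated:

```python
import torch
import torch.nn as nn

# Minimal sketch (toy Linear layer, not the original model):
# the same closure pattern as in the snippet above.
activation = {}

def get_activation(name, activation):
    def hook(model, input, output):
        activation[name] = output  # store the module's output under `name`
    return hook

layer = nn.Linear(4, 2)
layer.register_forward_hook(get_activation('linear-out', activation))

out = layer(torch.randn(1, 4))
# The hook runs during the forward call, so the dict is filled here.
print('linear-out' in activation)  # True
```

So as long as the hooked module's forward is actually called, the dictionary should be populated at the breakpoint.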
HOWEVER, if I simply move the code snippet from inside get_all_preds into the same method where best_model is instantiated, appending it right after the instantiation, I get the attention weights without any problem! I.e.:
best_model = my_model(args)
checkpoint = torch.load(args.model_save_dir+'/best_model.pt')
best_model.load_state_dict(checkpoint)
best_model = best_model.cpu()
activation = {}

def get_activation(name, activation):
    def hook(model, input, output):
        activation[name] = output
    return hook
best_model.transformer.transformer_encoder.layers[-1].self_attn.register_forward_hook(get_activation('last-layer-attention', activation))
best_model.eval()
preds = best_model(test_batch)
<Breakpoint>
My question is: why does passing the model to another method break the hook's pipeline?
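For comparison, here is a minimal sketch (a toy nn.Linear layer and a hypothetical run_inference helper, not the original get_all_preds) suggesting that merely registering the hook inside another function does not, by itself, prevent it from firing, since Python passes the module object by reference:

```python
import torch
import torch.nn as nn

# Minimal sketch (hypothetical run_inference helper, toy Linear layer):
# a hook registered *inside* another function still fires, because the
# function receives a reference to the very same module object.
captured = {}

def make_hook(name):
    def hook(module, input, output):
        captured[name] = output
    return hook

def run_inference(model, batch):
    model.register_forward_hook(make_hook('out'))
    model.eval()
    return model(batch)

model = nn.Linear(4, 2)
preds = run_inference(model, torch.randn(1, 4))
print('out' in captured)  # True: the indirection alone does not break the hook
```

If the toy case works but the original does not, the difference presumably lies elsewhere, e.g. in whether the hooked self_attn submodule's forward is actually invoked during that particular call.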