def extract_features(self, t_model, s_model, t_model_cfg, s_model_cfg):
    global t_FM
    t_FM = []
    global s_FM
    s_FM = []
    def hook1(module1, input, output):
        print("hooker Working")
        t_FM.append(output)
    handle = t_model.tbackbone.layer2[1].conv1.register_forward_hook(hook1)
    handle.remove()
    def hook2(module2, input, output):
        print("hooker Working")
        s_FM.append(output)
    handle = s_model.sbackbone.layer2[1].conv1.register_forward_hook(hook2)
    handle.remove()
    return t_FM, s_FM
Remove the hook handle only after the hook has fired, i.e. after the forward pass.
Hello @sio277,
could you explain in more detail?
Thank you.
The registered hooks were removed before they were ever called, so no forward pass ran while they were attached. Try something like this:
t_FM = []
s_FM = []
hooks = []
def hook1(module1, input, output):
    print("hooker Working")
    t_FM.append(output)
hooks.append(t_model.tbackbone.layer2[1].conv1.register_forward_hook(hook1))
def hook2(module2, input, output):
    print("hooker Working")
    s_FM.append(output)
hooks.append(s_model.sbackbone.layer2[1].conv1.register_forward_hook(hook2))
your_model(input_tensors)  # forward pass; fires every hook registered on the modules it runs
for hook_handle in hooks:
    hook_handle.remove()
Also, you don’t need to set t_FM and s_FM as global (nonlocal) variables.
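For anyone who wants to try the register, forward, remove order in isolation, here is a minimal sketch. The small `nn.Sequential` model and the name `save_output` are made up for illustration and stand in for the backbones in the question; only `register_forward_hook` and `handle.remove()` come from the thread.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the teacher/student backbones in the question.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

features = []  # filled by the hook during the forward pass


def save_output(module, inputs, output):
    # Called right after the hooked module's forward() returns.
    features.append(output)


# Register the hook and keep the handle so it can be removed later.
handle = model[0].register_forward_hook(save_output)

x = torch.randn(2, 3, 32, 32)
model(x)         # the hook fires here
handle.remove()  # remove only after the forward pass

model(x)         # no hook is attached any more, so nothing is appended
print(len(features), tuple(features[0].shape))  # 1 (2, 8, 32, 32)
```

If the handle were removed before the first `model(x)`, as in the original post, `features` would stay empty.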
Check the source code of nn.Module.__call__(), which is where forward() is run and all the registered hooks are fired.
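A quick way to see this without reading the source is to compare calling a module with calling its `forward()` directly. This is only a small sketch; the `nn.Linear` layer and the `calls` list are illustrative and not from the thread.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
calls = []  # records each time the hook runs

# list.append returns None, so the hook leaves the layer's output unchanged.
handle = layer.register_forward_hook(lambda module, inputs, output: calls.append("fired"))

x = torch.randn(1, 3)
layer(x)          # goes through nn.Module.__call__, so the hook fires
layer.forward(x)  # bypasses __call__, so the hook on this module does not fire
handle.remove()

print(calls)  # ['fired']
```

This is why the hook has to stay registered until the model has actually been called on an input.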
Thank you so much @sio277
I made them global because I want to use them in another function.
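One possible way to share the feature maps with another function without globals is to return them and pass them on. The sketch below assumes a changed signature (the hooked layers and the input batch `x` are passed in explicitly) and uses tiny `nn.Sequential` models as hypothetical stand-ins for the real teacher and student; it is not the original poster's code.

```python
import torch
import torch.nn as nn


def extract_features(t_model, s_model, t_layer, s_layer, x):
    """Run both models once and return the outputs of the hooked layers."""
    t_FM, s_FM = [], []
    handles = [
        t_layer.register_forward_hook(lambda m, i, o: t_FM.append(o)),
        s_layer.register_forward_hook(lambda m, i, o: s_FM.append(o)),
    ]
    try:
        t_model(x)  # teacher hook fires here
        s_model(x)  # student hook fires here
    finally:
        # Always detach the hooks, even if a forward pass raises.
        for handle in handles:
            handle.remove()
    return t_FM, s_FM


# Hypothetical tiny models standing in for the real networks.
teacher = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
student = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 2))
x = torch.randn(5, 4)

t_FM, s_FM = extract_features(teacher, student, teacher[0], student[0], x)
print(tuple(t_FM[0].shape), tuple(s_FM[0].shape))  # (5, 8) (5, 6)
```

The caller can then hand `t_FM` and `s_FM` to whatever function needs them, for example a distillation loss.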