The order of things is
1. register hook, keep handle,
2. make use of the hook,
3. remove hook.
Your code has 1 and 3 without any 2 in between. This is wrong, either because you forgot 2, or because you think you are coding only 3 but in reality are doing 1+3 and never undoing the effect of an earlier step 1.
You could print m._forward_hooks at various times in your code to get a feeling for what your code does w.r.t. forward hooks. (Note that m._forward_hooks is an internal data structure, so there are no guarantees that it stays viable, but for now it seems to be OK for educational purposes.)
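To make the three steps concrete, here is a minimal sketch on a toy module. The module `m`, the hook function `save_output`, and the list `captured` are made up for illustration and are not taken from your code; `m._forward_hooks` is the internal attribute mentioned above, used here only for inspection.

```python
import torch
from torch import nn

m = nn.Linear(4, 2)
captured = []

def save_output(module, inputs, output):
    # Forward hooks are called as hook(module, inputs, output);
    # returning None leaves the module's output unchanged.
    captured.append(output.detach())

# 1. register hook, keep handle
handle = m.register_forward_hook(save_output)
print(len(m._forward_hooks))  # 1 (internal attribute, for inspection only)

# 2. make use of the hook
with torch.no_grad():
    m(torch.randn(3, 4))
print(captured[0].shape)  # torch.Size([3, 2])

# 3. remove hook
handle.remove()
print(len(m._forward_hooks))  # 0

# A further forward pass no longer triggers the hook.
with torch.no_grad():
    m(torch.randn(3, 4))
print(len(captured))  # still 1
```

The handle returned in step 1 is what makes step 3 possible, so it has to be kept around for as long as the hook is registered.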
Best regards
Thomas
Hi Tom,
Thanks for your kind reply.
I think I am using the hook:
hooks1 = SimilarityHook.create_hooks(model1, ['name1', ...])
hooks2 = SimilarityHook.create_hooks(model2, ['name1', ...])
with torch.no_grad():
    model1(input)
    model2(input)
[[hook1.distance(hook2) for hook2 in hooks2] for hook1 in hooks1]
When I run the forward pass, the hooks record the tensors, which are later used to compute “distances between models”.
Perhaps a full example will make more sense (with your numbering):
# - register hooks
hooks1 = SimilarityHook.create_hooks(model1, ['name1', ...])
hooks2 = SimilarityHook.create_hooks(model2, ['name1', ...])
# - use hooks
with torch.no_grad():
    model1(input)
    model2(input)
dists: list[list[float]] = [[hook1.distance(hook2) for hook2 in hooks2] for hook1 in hooks1]
print(dists)
for hook in [*hooks1, *hooks2]:  # clear stored tensors
    hook.clear()
# - unregister *all* hooks
remove_hooks(model1, hooks1)
remove_hooks(model2, hooks2)
I think that is a complete example.
Is this correct now? I suppose I am most interested that step 3 works.
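One way to check that step 3 works is to count the registered forward hooks before and after removal. The sketch below does not use the SimilarityHook API: `capture_outputs` and `count_forward_hooks` are hypothetical helpers written only for this illustration, the toy `nn.Sequential` model stands in for the real models, and the count relies on the internal `_forward_hooks` attribute, so it is meant for debugging only.

```python
from contextlib import contextmanager

import torch
from torch import nn

@contextmanager
def capture_outputs(model, layer_names):
    """Hypothetical helper: record outputs of named submodules and
    always remove the hooks on exit, even if the forward pass raises."""
    outputs = {}
    handles = []
    modules = dict(model.named_modules())
    try:
        for name in layer_names:
            # Bind `name` as a default argument so each hook keeps its own key.
            def hook(module, inputs, output, name=name):
                outputs[name] = output.detach()
            handles.append(modules[name].register_forward_hook(hook))
        yield outputs
    finally:
        for handle in handles:
            handle.remove()

def count_forward_hooks(model):
    # Uses the internal _forward_hooks dict of every submodule.
    return sum(len(mod._forward_hooks) for mod in model.modules())

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
with capture_outputs(model, ["0", "2"]) as outputs:
    hooks_during = count_forward_hooks(model)
    with torch.no_grad():
        model(torch.randn(3, 4))
hooks_after = count_forward_hooks(model)
print(hooks_during, hooks_after, sorted(outputs))  # 2 0 ['0', '2']
```

If the count after removal is not zero, some handle was lost or never removed, which is the 1+3 situation described earlier in the thread.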