How to remove multiple hooks?

Is this correct for removing multiple hooks?

from torch import nn


def remove_hook(mdl: nn.Module, hook):
    """
    ref: https://github.com/pytorch/pytorch/issues/5037
    """
    handle = mdl.register_forward_hook(hook)
    handle.remove()


def remove_hooks(mdl: nn.Module, hooks: list):
    """
    ref: https://github.com/pytorch/pytorch/issues/5037
    """
    for hook in hooks:
        remove_hook(mdl, hook)

my code example:

        >>> hooks1 = SimilarityHook.create_hooks(model1, ['name1', ...])
        >>> hooks2 = SimilarityHook.create_hooks(model2, ['name1', ...])
        >>> with torch.no_grad():
        ...     model1(input)
        ...     model2(input)
        >>> [[hook1.distance(hook2) for hook2 in hooks2] for hook1 in hooks1]

No, you need to keep the handles returned when you register the hooks, and then call handle.remove() on those.

Apologies, I'm not sure what you mean. What is wrong with the code I proposed?

Thanks for your time in advance.

The order of things is

  1. register hook, keep handle,
  2. make use of the hook,
  3. remove hook.

Your code has 1 and 3 without any 2 in between. This is wrong: either you forgot 2, or you think you are coding only 3 but in reality you are doing 1+3, and the removal does not undo the effect of the original step 1.
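To see the 1+3 behavior concretely, here is a minimal sketch (the toy module and hook function are illustrative, not from the original code):

```python
import torch
from torch import nn

m = nn.Linear(2, 2)

def hook(module, inputs, output):
    pass

# step 1: the original registration, handle kept
h = m.register_forward_hook(hook)
print(len(m._forward_hooks))  # 1 hook registered

# what the proposed remove_hook does: register the same function *again*,
# then remove only the newly created handle
h2 = m.register_forward_hook(hook)
h2.remove()
print(len(m._forward_hooks))  # still 1 -- the original hook was never removed
```

Each call to register_forward_hook creates a new entry with its own id, so removing the second handle leaves the first registration in place.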

You could print m._forward_hooks at various points in your code to get a feeling for what it does w.r.t. forward hooks. (Note that m._forward_hooks is an internal data structure, so there are no guarantees it will stay viable, but for now it seems OK for educational purposes.)
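The three steps might look like this in a minimal self-contained sketch (the toy module and hook are placeholders for the real ones):

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
activations = []

def hook(module, inputs, output):
    # store the layer output for later use
    activations.append(output.detach())

# 1. register hook, keep handle
handle = model.register_forward_hook(hook)

# 2. make use of the hook
with torch.no_grad():
    model(torch.randn(3, 4))
print(len(activations))  # 1 -- the output was captured

# 3. remove hook via the saved handle
handle.remove()
print(len(model._forward_hooks))  # 0 -- nothing left registered
```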

Best regards

Thomas


Hi Tom,

Thanks for your kind reply.

I think I am using the hook:

hooks1 = SimilarityHook.create_hooks(model1, ['name1', ...])
hooks2 = SimilarityHook.create_hooks(model2, ['name1', ...])
with torch.no_grad():
    model1(input)
    model2(input)
[[hook1.distance(hook2) for hook2 in hooks2] for hook1 in hooks1]

when I run the forward pass the hook registers the tensors and then later uses those to compute “distances between models”.

Perhaps a full example will make more sense (with your numbering):

# - register hooks
hooks1 = SimilarityHook.create_hooks(model1, ['name1', ...])
hooks2 = SimilarityHook.create_hooks(model2, ['name1', ...])

# - use hooks
with torch.no_grad():
    model1(input)
    model2(input)
dists: list[list[float]] = [[hook1.distance(hook2) for hook2 in hooks2] for hook1 in hooks1]
print(dists)
# - clear stored tensors
for hook in hooks1 + hooks2:
    hook.clear()

# - unregister *all* hooks
remove_hooks(model1, hooks1)
remove_hooks(model2, hooks2)

I think that is a complete example.

Is this correct now? I suppose I am most interested that step 3 works.
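For step 3, a variant of remove_hooks that operates on the RemovableHandle objects returned by register_forward_hook (rather than re-registering the hook callables) might look like this sketch; whether SimilarityHook exposes its handles depends on that library:

```python
import torch
from torch import nn

def remove_hooks(handles: list) -> None:
    """Remove forward hooks via the handles saved at registration time."""
    for handle in handles:
        handle.remove()

model = nn.Linear(3, 3)
# register two no-op hooks, keeping the handles
handles = [model.register_forward_hook(lambda m, i, o: None) for _ in range(2)]
print(len(model._forward_hooks))  # 2

remove_hooks(handles)
print(len(model._forward_hooks))  # 0
```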

What does the printing of the hooks say?

Hi Tom,

Could you please help me with this?
https://discuss.pytorch.org/t/pytorch-hooks-removal/