How do I remove forward hooks on a module without the hook handles?

I can’t seem to figure out how to remove the hooks without the handles, despite being able to detect the hooks on the target module.

from typing import Tuple
import torch

# Build four *distinct* Identity layers. `[torch.nn.Identity()] * 4` would repeat
# a single instance four times, so a hook registered on model[1] would end up on
# every slot of the Sequential.
model = torch.nn.Sequential(*(torch.nn.Identity() for _ in range(4)))

def forward_hook(module: torch.nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor) -> None:
    """No-op forward hook used as a placeholder when registering hooks.

    Args:
        module: The module the hook is attached to (unused).
        input: The positional inputs passed to the module (unused).
        output: The module's output (unused).
    """
    return None

def find_and_remove_hooks(m):
    """Report and remove every forward hook registered on ``m``.

    Works without the handles returned by ``register_forward_hook`` by
    editing the module's private hook registry directly.

    Args:
        m: The ``torch.nn.Module`` whose forward hooks should be removed.
    """
    print("Hooks")
    print(m._forward_hooks)
    # Snapshot the ids first: deleting from a dict while iterating over it
    # raises RuntimeError. The registry is keyed by hook id (an int), and its
    # values are the bare hook callables, which have no ``remove`` method.
    hook_ids = list(m._forward_hooks)
    for hook_id in hook_ids:
        del m._forward_hooks[hook_id]
        # Newer PyTorch releases keep per-hook flags in companion dicts keyed
        # by the same id; drop those entries as well when the attributes exist.
        for extra_name in ("_forward_hooks_with_kwargs", "_forward_hooks_always_called"):
            extra = getattr(m, extra_name, None)
            if extra is not None:
                extra.pop(hook_id, None)
    print("All " + str(len(hook_ids)) + " hooks found")

# Register a hook on one submodule and discard the returned handle, reproducing
# the situation where the hook has to be removed without it.
model[1].register_forward_hook(forward_hook)

# Try to detect and remove that hook through the module's private hook registry.
find_and_remove_hooks(model[1])

Looks like using this code does nothing:

# Build a brand-new handle around the module's hook registry and remove it.
# NOTE: `module` is not defined in this snippet; it stands for the hooked module.
# NOTE(review): reported to have no effect — presumably the new handle's id
# matches no registered hook; confirm against torch.utils.hooks.RemovableHandle.
handle = torch.utils.hooks.RemovableHandle(module._forward_hooks)
handle.remove()

While simply setting the value back to its default removes the hooks:

from collections import OrderedDict
from typing import Dict, Callable

# Rebind the private forward-hook registry to a fresh, empty OrderedDict, which
# drops every forward hook currently stored on `module`.
# NOTE: `module` is not defined in this snippet; it stands for the hooked module.
module._forward_hooks: Dict[int, Callable] = OrderedDict()

All forward hooks can be removed like this:

from collections import OrderedDict
from typing import Dict, Callable
import torch

def remove_all_forward_hooks(model: torch.nn.Module) -> None:
    """Remove every forward hook from ``model`` and all of its submodules.

    The hooks are dropped by emptying each module's private hook registry, so
    no hook handles are required.

    Args:
        model: Root module. Hooks registered directly on it are removed too,
            not only those on its descendants.
    """
    # ``modules()`` yields the root itself plus every nested submodule once,
    # so hooks registered on ``model`` are covered as well as its children.
    for module in model.modules():
        # The companion dicts exist only on newer PyTorch releases and are
        # keyed by the same hook ids; empty them too so no stale ids remain.
        for registry_name in (
            "_forward_hooks",
            "_forward_hooks_with_kwargs",
            "_forward_hooks_always_called",
        ):
            registry = getattr(module, registry_name, None)
            if registry is not None:
                registry.clear()

And all the hooks can be removed like this, using the attributes from [here]:

from collections import OrderedDict
from typing import Dict, Callable
import torch

def remove_all_hooks(model: torch.nn.Module) -> None:
    """Remove forward, forward-pre and backward hooks from ``model`` and its submodules.

    The hooks are dropped by emptying each module's private hook registries,
    so no hook handles are required.

    Args:
        model: Root module. Hooks registered directly on it are removed too,
            not only those on its descendants.
    """
    # Each registry is handled independently. Chaining them with ``elif``
    # would stop at ``_forward_hooks`` (present on every module) and never
    # reach the forward-pre or backward registries.
    registry_names = (
        "_forward_hooks",
        "_forward_pre_hooks",
        "_backward_hooks",
        # Present only on newer PyTorch releases; skipped when absent.
        "_forward_hooks_with_kwargs",
        "_forward_hooks_always_called",
        "_forward_pre_hooks_with_kwargs",
        "_backward_pre_hooks",
    )
    # ``modules()`` yields the root itself plus every nested submodule once.
    for module in model.modules():
        for registry_name in registry_names:
            registry = getattr(module, registry_name, None)
            if registry is not None:
                registry.clear()

This solution doesn’t seem to reset the global counter, though, as can be seen with the find_hooks function.