How to check where the hooks are in the model?

Is there a way to inspect a model and find out where its hooks are registered?


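There may be a better way, but the following works. It reads each module's private hook dictionaries (`_forward_hooks`, `_forward_pre_hooks`, `_backward_hooks`), which are stored in the module's __dict__: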

import torch.nn as nn


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, 3)

foo = Foo()

def hook(m, i, o):
    print(m, i, o)

def pre_hook(m, i):
    print(m, i)

foo.register_forward_hook(hook)
foo.conv.register_forward_hook(hook)
foo.register_forward_pre_hook(pre_hook)

save = []
def find_hook(m):
    # m._forward_hooks etc. are private dicts mapping hook id -> hook function
    module_name = type(m).__name__
    for k, v in m._forward_hooks.items():
        save.append((module_name, k, v.__name__))

    for k, v in m._forward_pre_hooks.items():
        save.append((module_name, k, v.__name__))

    for k, v in m._backward_hooks.items():
        save.append((module_name, k, v.__name__))

# apply() calls find_hook on every submodule first (children), then on foo itself
foo.apply(find_hook)
print(save)

which prints out

[('Conv2d', 1, 'hook'), ('Foo', 0, 'hook'), ('Foo', 2, 'pre_hook')]
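A variant of the same idea, as a sketch: the public `named_modules()` iterator yields each submodule together with its dotted path, so the report can say *where* in the module tree a hook sits (e.g. `conv`) instead of only the class name. It still reads the same private dictionaries, which may change between PyTorch versions; `list_module_hooks` is a hypothetical helper name.

```python
import torch.nn as nn


def list_module_hooks(model: nn.Module):
    """Hypothetical helper: (path, module_type, hook_kind, hook_id, hook_name) tuples."""
    found = []
    # named_modules() yields ("", model) for the root, then dotted paths such as "conv"
    for name, module in model.named_modules():
        for kind in ("_forward_hooks", "_forward_pre_hooks", "_backward_hooks"):
            # Private attributes; getattr with a default guards against version differences
            hooks = getattr(module, kind, None) or {}
            for hook_id, fn in hooks.items():
                found.append((name or "<root>", type(module).__name__, kind,
                              hook_id, getattr(fn, "__name__", repr(fn))))
    return found


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, 3)


def hook(m, i, o):
    pass


def pre_hook(m, i):
    pass


foo = Foo()
foo.register_forward_hook(hook)
foo.conv.register_forward_hook(hook)
foo.register_forward_pre_hook(pre_hook)

for entry in list_module_hooks(foo):
    print(entry)
```

The hook ids come from a process-wide counter, so their exact values depend on what else has registered hooks before.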

@sio277 Thanks for your suggestion, but I haven't found that it works for my case.
Sorry, I should have provided a more descriptive code snippet.
Let's say I have the following code.
How can I find out where I added the hooks?


import torch
import torch.nn as nn
from torchvision import models

# ResNet Class
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        
        # define the resnet50
        # self.resnet = resnet50(pretrained=True)
        self.resnet = models.resnet50(pretrained=True)

        
        # isolate the feature blocks
        self.features = nn.Sequential(self.resnet.conv1,
                                      self.resnet.bn1,
                                      nn.ReLU(),
                                      nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
                                      self.resnet.layer1, 
                                      self.resnet.layer2, 
                                      self.resnet.layer3, 
                                      self.resnet.layer4)
        
        # average pooling layer
        self.avgpool = self.resnet.avgpool
        
        # classifier
        self.classifier = self.resnet.fc
        
        # gradient placeholder
        self.gradient = None
    
    # hook for the gradients
    def activations_hook(self, grad):
        self.gradient = grad
    
    def get_gradient(self):
        return self.gradient
    
    def get_activations(self, x):
        return self.features(x)
    
    def forward(self, x):
        
        # extract the features
        x = self.features(x)
        
        # register the hook
        h = x.register_hook(self.activations_hook)
        
        # complete the forward pass
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        
        return x

resnet = ResNet()
img = torch.rand(1,3,224,224)
# forward pass
pred = resnet(img)
print(pred.argmax(dim=1))  # printed tensor([2]) in this run; varies with the random input


# backpropagate the class-2 score; the tensor hook stores its gradient w.r.t. the feature-map activations
pred[:, 2].backward()

# pull the gradients out of the model
gradients = resnet.get_gradient()

# pool the gradients across the channels
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# get the activations of the last convolutional layer
activations = resnet.get_activations(img).detach()

# weight the channels by corresponding gradients
for i in range(activations.size(1)):  # 2048 channels for ResNet-50, not 512
    activations[:, i, :, :] *= pooled_gradients[i]
    
# average the channels of the activations
heatmap = torch.mean(activations, dim=1).squeeze()
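The per-channel weighting loop can also be written with broadcasting. A minimal sketch, using random stand-in tensors with the shape ResNet-50's layer4 produces for a 224x224 input ([1, 2048, 7, 7]); it checks that a loop sized from the tensor and the broadcast version give the same result.

```python
import torch

# Random stand-ins for the real activations/gradients (ResNet-50 layer4 shape on 224x224)
activations = torch.rand(1, 2048, 7, 7)
gradients = torch.rand(1, 2048, 7, 7)

# Average the gradients over batch and spatial dims -> one weight per channel
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])  # shape [2048]

# Loop version, sized from the tensor rather than a hard-coded channel count
looped = activations.clone()
for i in range(looped.size(1)):
    looped[:, i, :, :] *= pooled_gradients[i]

# Broadcast version: reshape the weights to [1, C, 1, 1] and multiply in one step
broadcast = activations * pooled_gradients.view(1, -1, 1, 1)

heatmap = torch.mean(broadcast, dim=1).squeeze()  # shape [7, 7]
print(torch.allclose(looped, broadcast), heatmap.shape)
```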


Oh, so you attach the hook to a tensor in the forward pass, to get the gradient with respect to the self.features(x) output. But what do you mean by ‘how can I know where I added the hooks’?
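This difference matters for inspection: a tensor hook registered inside forward() is attached to an intermediate tensor created during that call, so no module stores it and inspecting the model right after construction finds nothing. A minimal sketch, assuming a hypothetical torchvision-free `TinyNet` in place of the ResNet wrapper and a hypothetical `has_module_hooks` helper that reads the private per-module hook dicts. It also shows one alternative: registering the tensor hook from a module forward hook, which *is* stored on the module.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical, torchvision-free stand-in for the ResNet wrapper in the question."""

    def __init__(self, hook_in_forward=True):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU())
        self.classifier = nn.Linear(4, 2)
        self.hook_in_forward = hook_in_forward
        self.gradient = None

    def activations_hook(self, grad):
        self.gradient = grad

    def forward(self, x):
        x = self.features(x)
        if self.hook_in_forward:
            # Tensor hook, as in the question: it lives on this intermediate tensor only
            x.register_hook(self.activations_hook)
        x = x.mean(dim=[2, 3])  # global average pool -> shape [N, 4]
        return self.classifier(x)


def has_module_hooks(model):
    """True if any submodule has a forward, forward-pre or backward hook (private dicts)."""
    return any(
        len(m._forward_hooks) or len(m._forward_pre_hooks) or len(m._backward_hooks)
        for m in model.modules()
    )


# Hook registered inside forward(): nothing to find right after construction
net = TinyNet(hook_in_forward=True)
print(has_module_hooks(net))  # False

# Alternative: a forward hook on `features` attaches the tensor hook on each call
net2 = TinyNet(hook_in_forward=False)


def attach_grad_hook(module, inputs, output):
    # Return None on purpose: a non-None return value would replace the module's output
    output.register_hook(net2.activations_hook)


net2.features.register_forward_hook(attach_grad_hook)
print(has_module_hooks(net2))  # True

out = net2(torch.rand(1, 3, 8, 8))
out[0, 0].backward()
print(net2.gradient.shape)  # torch.Size([1, 4, 6, 6])
```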

I mean, let's say I have that model, where I attached the hook to a tensor in the forward pass, and now I don't know where I attached it (basically, I don't want to go and look into the model design).
So now I load the model as resnet = ResNet(). Is there a way to print where the hook is attached at this stage?
I understand that this is not very smart, but it is just a sanity check of where the hooks are attached without having to look into the model.
Is there a way to know where the hooks are attached after I do resnet = ResNet()?
Basically, I'm wondering how I can show which tensors (or modules) the hooks are attached to, using only resnet.

Any suggestions? @albanD @ptrblck

What do you mean by where it is attached? Like which line of code called “register_*_hook()”?
Or which nn.Module has a hook associated with it?

@albanD The latter: which nn.Module has a hook associated with it. Or, more generally, which nn.Module or tensor has a hook attached to it.

Not sure why you mention the former case, but just out of curiosity, would it even make sense, or be of any use, to know which line of code calls “register_*_hook()”?

Not sure, no :smiley: Maybe you have a big library with some hooks that you want to remove, but you don’t know who added them, haha.
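The stored hook does not record which line called register_*_hook(). What can be recovered, as a rough sketch, is where the hook *function* was defined, by reading its code object. `describe_hook` is a hypothetical helper; callables without a `__code__` (e.g. functools.partial objects) fall back to repr().

```python
import torch.nn as nn


def describe_hook(fn):
    """Hypothetical helper: where was this hook callable *defined*?

    This does not reveal which line called register_*_hook(); it only reads
    the function's code object.
    """
    target = getattr(fn, "__func__", fn)  # unwrap bound methods
    code = getattr(target, "__code__", None)
    if code is None:  # e.g. functools.partial or callable objects
        return repr(fn)
    return f"{target.__qualname__} defined at {code.co_filename}:{code.co_firstlineno}"


def my_hook(module, inputs, output):
    pass


conv = nn.Conv2d(1, 2, 3)
conv.register_forward_hook(my_hook)
for hook_id, fn in conv._forward_hooks.items():
    print(hook_id, describe_hook(fn))
```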

I’m afraid there is no public API, for either Tensors or Modules, to check whether they have hooks.
But currently (version 1.8), you can check these private attributes:

  • Global Module hooks via `torch.nn.modules.module._global_{forward,forward_pre,backward}_hooks`.
  • Per-Module hooks via `your_mod._{forward,forward_pre,backward}_hooks`.
  • Per-Tensor hooks via `your_tensor._backward_hooks` (there are no forward hooks on Tensors).
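The three places listed above can be checked together. A minimal sketch, assuming the private attribute names stay as listed (they are not a stable API); `collect_hooks` is a hypothetical helper. Note that a tensor hook can only be found if you still hold a reference to that tensor, and `Tensor._backward_hooks` stays None until a hook is registered.

```python
import torch
import torch.nn as nn
from torch.nn.modules import module as nn_module


def collect_hooks(model=None, tensors=None):
    """Hypothetical helper: (scope, location, kind, hook_id, hook) tuples from private dicts."""
    found = []
    # 1) Global module hooks, which apply to every nn.Module
    for kind in ("_global_forward_hooks", "_global_forward_pre_hooks", "_global_backward_hooks"):
        for hook_id, fn in (getattr(nn_module, kind, None) or {}).items():
            found.append(("global", "*", kind, hook_id, fn))
    # 2) Per-module hooks, reported by dotted module path
    if model is not None:
        for name, m in model.named_modules():
            for kind in ("_forward_hooks", "_forward_pre_hooks", "_backward_hooks"):
                for hook_id, fn in (getattr(m, kind, None) or {}).items():
                    found.append(("module", name or "<root>", kind, hook_id, fn))
    # 3) Per-tensor hooks; _backward_hooks is None until register_hook is called
    for label, t in (tensors or {}).items():
        for hook_id, fn in (getattr(t, "_backward_hooks", None) or {}).items():
            found.append(("tensor", label, "_backward_hooks", hook_id, fn))
    return found


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
model[0].register_forward_hook(lambda m, i, o: None)  # returns None, output unchanged
x = torch.rand(2, 4, requires_grad=True)
y = model(x)
y.register_hook(lambda grad: None)

for entry in collect_hooks(model, tensors={"x": x, "y": y}):
    print(entry[:4])
```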