Is there a way to inspect a model and find out where its hooks are located?
There may be a better way, but the following will work; it looks directly into each module's __dict__:
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, 3)

foo = Foo()

def hook(m, i, o):
    print(m, i, o)

def pre_hook(m, i):
    print(m, i)

foo.register_forward_hook(hook)
foo.conv.register_forward_hook(hook)
foo.register_forward_pre_hook(pre_hook)

save = []

def find_hook(m):
    # collect (module class name, hook id, hook function name) for every hook on m
    module_name = type(m).__name__
    for k, v in m._forward_hooks.items():
        save.append((module_name, k, v.__name__))
    for k, v in m._forward_pre_hooks.items():
        save.append((module_name, k, v.__name__))
    for k, v in m._backward_hooks.items():
        save.append((module_name, k, v.__name__))

# apply() visits the children first, then the module itself
foo.apply(find_hook)
print(save)
which prints out
[('Conv2d', 1, 'hook'), ('Foo', 0, 'hook'), ('Foo', 2, 'pre_hook')]
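If the class name alone is ambiguous (a model can contain many Conv2d layers), a variant of the same idea can walk named_modules() and report each module's dotted path instead. This is only a sketch: list_module_hooks is a hypothetical helper name, and it reads the same private _forward_hooks, _forward_pre_hooks and _backward_hooks dictionaries as the snippet above, so it may break if those internals change between PyTorch versions.

```python
import torch.nn as nn


def list_module_hooks(model):
    """Return (module_path, hook_kind, hook_name) for every hook found on model.

    Hypothetical helper: it relies on private nn.Module attributes.
    """
    found = []
    for path, module in model.named_modules():
        for kind in ("_forward_hooks", "_forward_pre_hooks", "_backward_hooks"):
            # getattr with a default guards against versions lacking an attribute
            hooks = getattr(module, kind, None) or {}
            for fn in hooks.values():
                name = getattr(fn, "__name__", repr(fn))
                # named_modules() yields "" for the root module
                found.append((path or "<root>", kind, name))
    return found


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, 3)


def hook(m, i, o):
    pass


def pre_hook(m, i):
    pass


foo = Foo()
foo.register_forward_hook(hook)
foo.conv.register_forward_hook(hook)
foo.register_forward_pre_hook(pre_hook)

hooks_found = list_module_hooks(foo)
for entry in hooks_found:
    print(entry)
```

Reporting the path ("conv" rather than "Conv2d") tells you which attribute of the model to look at, which is closer to "where" than the class name is.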
@sio277 Thanks for your suggestion, but I have not found that this solution works for my case.
Sorry, I should have provided a more descriptive example/snippet of code.
Let's say I have the following code.
How can I know where I added the hooks?
import torch
import torch.nn as nn
from torchvision import models

# ResNet Class
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        # define the resnet50
        self.resnet = models.resnet50(pretrained=True)
        # isolate the feature blocks
        self.features = nn.Sequential(self.resnet.conv1,
                                      self.resnet.bn1,
                                      nn.ReLU(),
                                      nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
                                      self.resnet.layer1,
                                      self.resnet.layer2,
                                      self.resnet.layer3,
                                      self.resnet.layer4)
        # average pooling layer
        self.avgpool = self.resnet.avgpool
        # classifier
        self.classifier = self.resnet.fc
        # gradient placeholder
        self.gradient = None

    # hook for the gradients
    def activations_hook(self, grad):
        self.gradient = grad

    def get_gradient(self):
        return self.gradient

    def get_activations(self, x):
        return self.features(x)

    def forward(self, x):
        # extract the features
        x = self.features(x)
        # register the hook on the feature tensor (a tensor hook, not a module hook)
        h = x.register_hook(self.activations_hook)
        # complete the forward pass
        x = self.avgpool(x)
        x = x.view((1, -1))
        x = self.classifier(x)
        return x
resnet = ResNet()
img = torch.rand(1, 3, 224, 224)
# forward pass
pred = resnet(img)
pred.argmax(dim=1)  # printed tensor([2]) in my run
# get the gradient of the output with respect to the parameters of the model
pred[:, 2].backward()
# pull the gradients out of the model
gradients = resnet.get_gradient()
# pool the gradients across the channels
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
# get the activations of the last convolutional layer
activations = resnet.get_activations(img).detach()
# weight the channels by corresponding gradients
# (loop over the actual channel count; resnet50's layer4 has 2048 channels, not 512)
for i in range(activations.shape[1]):
    activations[:, i, :, :] *= pooled_gradients[i]
# average the channels of the activations
heatmap = torch.mean(activations, dim=1).squeeze()
Oh, so you attach the hook to the forward tensor, to get a gradient at the self.features(x) tensor. But what do you mean by "how can I know where I added the hooks"?
I mean, let's say I have that model with the hook attached to the tensor in the forward pass, and let's say I no longer know where I attached the hook (so basically I don't want to go and look into the model design).
So now I load the model as resnet = ResNet(). Is there a way to print where the hook is attached at this stage?
I understand that this is not very smart, but it is just a sanity check of where the hooks are attached without having to look into the model.
Will there be a way to know where the hooks are attached after I do resnet = ResNet()?
Basically, I'm wondering how I can show where the hooks are attached to the tensors (or modules) by using resnet.
What do you mean by where it is attached? Like which line of code called "register_*_hook()"?
Or which nn.Module has a hook associated with it?
@albanD The latter one: which nn.Module has a hook associated with it. Or, in an even more general case, which nn.Module or tensor has a hook attached to it.
Not sure why you mention the former case, but just out of curiosity, would it even make sense or have any use to know which line of code calls "register_*_hook()"?
Not sure, no. Maybe you have a big library that has some hooks that you want to remove, but you don't know who added them. Haha.
For both Tensors and Modules, there is no public API to know if there are hooks on them, I'm afraid.
But currently (version 1.8), you can check:
- Global Module hooks via `torch.nn.modules.module._global_{forward,forward_pre,backward}_hooks`.
- Per-Module hooks via `your_mod._{forward,forward_pre,backward}_hooks`.
- Per-Tensor hooks via `your_tensor._backward_hooks` (there are no forward hooks on Tensor).
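The three places listed above can be inspected roughly like this. It is a minimal sketch, assuming a PyTorch version where these private attributes still exist (they are not public API and may change); noop_hook, mod, x and y are made-up names for illustration. Keep in mind that a tensor hook lives on the tensor it was registered on, so a hook registered inside forward() cannot be seen on a freshly constructed model; you would need a reference to that intermediate tensor.

```python
import torch
import torch.nn as nn
import torch.nn.modules.module as module_internals


def noop_hook(*args):
    # Accepts any hook signature and changes nothing
    return None


# 1. Global module hooks: registered once, they apply to every nn.Module
global_handle = module_internals.register_module_forward_hook(noop_hook)
print(module_internals._global_forward_hooks)

# 2. Per-module hooks
mod = nn.Linear(2, 2)
mod_handle = mod.register_forward_hook(noop_hook)
print(mod._forward_hooks)

# 3. Per-tensor hooks
x = torch.ones(2, requires_grad=True)
y = x * 2
print(y._backward_hooks)  # state before any hook is registered on y
tensor_handle = y.register_hook(noop_hook)
print(y._backward_hooks)  # now holds the registered hook

# Record how many hooks each container holds at this point
n_global = len(module_internals._global_forward_hooks)
n_module = len(mod._forward_hooks)
n_tensor = len(y._backward_hooks)

# Clean up so the global hook does not leak into other code
for handle in (global_handle, mod_handle, tensor_handle):
    handle.remove()
```

Each register call returns a handle, and keeping that handle around is the supported way to remove a hook later; the private dictionaries are best used read-only for sanity checks like this.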