How do hooks work?

I’m trying to implement the GradCam paper, which uses the gradient information flowing into the last convolutional layer of the CNN to assign importance values to each neuron for a particular decision of interest.
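For reference, the Grad-CAM paper (Selvaraju et al.) weights each feature map $A^k$ of the chosen conv layer by the spatially averaged gradient of the class score $y^c$ (taken before the softmax), with $Z$ the number of spatial positions in $A^k$. Written out, it makes clear that the quantity to capture is $\partial y^c / \partial A^k$, a gradient w.r.t. the activations rather than the weights:

$$
\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)
$$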

The paper says -

1. The first step in implementing GradCam would be obtaining the gradient w.r.t. the activation maps. How do I obtain it? From what I’ve read online, we need hooks to get it. My question is: how is the value obtained using hooks different from cnnlayer.weight.grad?

2. Now, let’s say we need hooks. How do we actually use hooks here? The docs say that we can use hooks on either tensors or modules. How are they different? Does a module refer to a single layer such as nn.Conv2d(...), or does it refer to the entire class shown below?

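```python
import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # Layer sizes are illustrative (assumed: 1x28x28 inputs, 10 classes)
        self.cnn1 = nn.Conv2d(1, 16, kernel_size=3)
        self.cnn2 = nn.Conv2d(16, 32, kernel_size=3)
        # cnn_out gives a (1,n) output where n is the number of classes.
        # We use cnn_out instead of a fc layer.
        self.cnn_out = nn.Conv2d(32, 10, kernel_size=24)

    def forward(self, x):
        x = self.cnn1(x)
        x = self.cnn2(x)
        # flatten(1) turns the (1,n,1,1) conv output into (1,n)
        out = self.cnn_out(x).flatten(1)

        return out
```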

i.e., does module refer to CNNModel or nn.Conv2d()?

3. We can use hooks on a tensor by way of register_hook(). How do we use it? From what I’ve read online, we need to use hooks on a tensor to implement GradCAM. Let’s say my setup is the following:

I have a custom image classifier as follows. How do I use register_hook to obtain the gradients of self.cnn2? I’ve seen examples where register_hook is used inside the forward() method (Source). Sometimes, people make a separate class out of it (Source (pg132)).

What is the correct way of doing it? Assume the trained model is stored in variable model as shown below.

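```python
import torch
import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # Illustrative sizes (assumed): 28x28 -> 26x26 -> 24x24 -> 1x1
        self.cnn1 = nn.Conv2d(1, 16, kernel_size=3)
        self.cnn2 = nn.Conv2d(16, 32, kernel_size=3)
        # cnn_out gives a (1,n) output where n is the number of classes.
        # We use cnn_out instead of a fc layer.
        self.cnn_out = nn.Conv2d(32, 10, kernel_size=24)

    def forward(self, x):
        x = self.cnn1(x)
        x = self.cnn2(x)

        # we return a (1,10) output because we have 10 classes in MNIST.
        # cnn_out produces (1,10,1,1), so flatten the trailing 1x1 dims.
        out = self.cnn_out(x).flatten(1)

        return out

# Let's assume this model has been trained on MNIST (28 x 28 images)
model = CNNModel()

# Obtain prediction for a single image
# (placeholder input: one 1-channel 28x28 image; use a real test image)
single_test_img = torch.randn(1, 1, 28, 28)
pred = model(single_test_img)
```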

What do we do after we obtain pred? Can you please show that by completing the pseudo-code I’ve written? That is, where do register_hook (defining and running it) and everything else fit in when we try to obtain the gradient mentioned at the top?

4. Can we use register_backward_hook() or register_forward_hook() somehow? What I mean is, can these two functionalities be used to get the same output as register_hook()?

Can you give a small example of how we use register_forward_hook() using the CNNModel class I defined above?

I’ve also seen the docs say that register_backward_hook() is buggy, so we are advised not to use it.

I am confused by how hooks work. Hooks in PyTorch do not receive a lot of attention (docs or blogs). Please feel free to answer individual chunks, I’ll follow up in the comments. :slight_smile:


  1. Hooks allow you to get gradients w.r.t. intermediate results of the forward pass, not just w.r.t. the weights/biases.

  2. For backward hooks, you should only use the Tensor version right now (the nn.Module version is not working properly at the moment).
    To use a Tensor hook on the output of cnn1, for example, you would modify your forward method as follows:

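```python
import torch.nn as nn

# Defined somewhere else
def my_hook(grad):
    # Do what you want with the grad (dL/dx for the output of cnn1)
    print(grad.shape)

# In your model definition
class CNNModel(nn.Module):
    # ... __init__ as in the question ...

    def forward(self, x):
        x = self.cnn1(x)
        x.register_hook(my_hook)  # called with dL/dx during backward
        x = self.cnn2(x)
        out = self.cnn_out(x)

        return out
```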
  1. As in the example above.
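Putting points 1 to 3 together, here is a minimal sketch of how the question's pseudo-code could be completed with a Tensor hook. It assumes illustrative layer sizes for 1x28x28 MNIST inputs; the attribute names activations/gradients, the save_gradient method, the grad_cam helper, and the random placeholder image are hypothetical. The hook is registered inside forward() on the output of cnn2, so after backward() it holds dL/dA with the same shape as the activations, unlike cnn2.weight.grad, which has the shape of the weights:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative layer sizes (assumed) for 1x28x28 MNIST input
        self.cnn1 = nn.Conv2d(1, 16, kernel_size=3)
        self.cnn2 = nn.Conv2d(16, 32, kernel_size=3)
        self.cnn_out = nn.Conv2d(32, 10, kernel_size=24)
        self.activations = None  # output of cnn2 (A in Grad-CAM)
        self.gradients = None    # dL/dA, filled in during backward

    def save_gradient(self, grad):
        # Tensor hook: receives the gradient w.r.t. the hooked tensor
        self.gradients = grad

    def forward(self, x):
        x = self.cnn1(x)
        x = self.cnn2(x)
        self.activations = x
        # register_hook fails on tensors that don't require grad (e.g. under no_grad)
        if x.requires_grad:
            x.register_hook(self.save_gradient)
        return self.cnn_out(x).flatten(1)


def grad_cam(model, img, target_class=None):
    model.zero_grad()
    scores = model(img)  # shape (1, 10)
    if target_class is None:
        target_class = scores.argmax(dim=1).item()
    # Backprop only the score of the class of interest
    scores[0, target_class].backward()

    A = model.activations.detach()                # (1, K, H, W)
    dA = model.gradients                          # (1, K, H, W)
    weights = dA.mean(dim=(2, 3), keepdim=True)   # alpha_k: averaged gradients
    cam = F.relu((weights * A).sum(dim=1))        # (1, H, W)
    # Upsample to the input resolution so the map can be overlaid on the image
    cam = F.interpolate(cam.unsqueeze(1), size=img.shape[-2:],
                        mode="bilinear", align_corners=False)
    return cam.squeeze(), target_class


model = CNNModel()               # assume trained weights are loaded here
img = torch.randn(1, 1, 28, 28)  # placeholder for a real MNIST test image
cam, cls = grad_cam(model, img)
print(cam.shape, cls)
# dL/dA (hook) vs dL/dW (weight.grad): different quantities, different shapes
print(model.gradients.shape, model.cnn2.weight.grad.shape)
```

Here model.gradients has shape (1, 32, 24, 24), matching the cnn2 output, while cnn2.weight.grad has shape (32, 16, 3, 3), matching the kernel. Grad-CAM needs the former.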
Hey @albanD, thanks for taking the time to answer. I have a few follow-up questions.

  1. What do you mean when you say intermediate results? Gradients are calculated for a weight (dL/dW). By intermediate results, do you mean dL/dA, where A is the activation of a layer that has weight W? Can you show what these intermediate results are in the following image, please?

  2. How is register_hook applied on a tensor different from register_backward_hook() applied on an nn.Module? What is nn.Module here: a layer such as nn.Conv2d, or an entire class like CNNModel()? I have this confusion because one example in the docs has net.conv2.register_backward_hook(printnorm) (used on the layer conv2), but modules are supposed to be classes that have a .forward() function. Or is my understanding of nn.Module incorrect?

Can you highlight where register_hook(), register_forward_hook(), register_backward_hook() apply in the attached image please?

In the code example you wrote, register_hook() is applied after self.cnn1. So, the hook (it’s like a callback function, right?) can now access the gradients flowing into self.cnn1 during backprop?


  1. In your image, the intermediate results are z and a, and the corresponding gradients for them are dL/dz and dL/da.
    Note that in your image, all the boxes correspond to functions. The purple function is what you define in your forward, and the red functions are the corresponding backward Nodes.
    The Tensors containing values are on the arrows: the purple ones are the intermediate results you get when you write your forward.
    The green are what is saved for backward (you don’t see these).
    And the red are the gradients flowing back, containing the gradient for each intermediate result.
    You can see a Tensor hook as observing one of these red edges.

  2. nn.Module builds on top of the autograd to have a nice neural network library.
    So a hook on an nn.Module tries (and at the moment fails) to get the intermediate gradients for all the inputs/outputs of its forward function.
    Hooking a Tensor uses the autograd construct directly and gives you the gradient computed for that specific Tensor.
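To make the nn.Module vs. Tensor distinction concrete (and to address question 4), here is a minimal sketch assuming PyTorch 1.8 or later, which added Module.register_full_backward_hook as the working replacement for register_backward_hook. It hooks the cnn2 layer (a submodule, itself an nn.Module, just like the whole CNNModel) without touching forward(): a forward hook captures the activations and attaches a Tensor hook to them, and a full backward hook on the same layer reports the same dL/d(output). Layer sizes, the store dict, and the random input are illustrative assumptions:

```python
import torch
import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative layer sizes (assumed) for 1x28x28 MNIST input
        self.cnn1 = nn.Conv2d(1, 16, kernel_size=3)
        self.cnn2 = nn.Conv2d(16, 32, kernel_size=3)
        self.cnn_out = nn.Conv2d(32, 10, kernel_size=24)

    def forward(self, x):
        # Unchanged forward: no hook code inside the model
        x = self.cnn1(x)
        x = self.cnn2(x)
        return self.cnn_out(x).flatten(1)


model = CNNModel()
# Both the whole model and each layer are nn.Module instances
assert isinstance(model, nn.Module) and isinstance(model.cnn2, nn.Module)

store = {}

def forward_hook(module, inputs, output):
    # Module forward hook: runs after module.forward() and sees its output
    store["activations"] = output.detach()

    def save_grad(grad):
        # Tensor hook: receives dL/d(output) during backward
        store["tensor_grad"] = grad

    output.register_hook(save_grad)

def full_backward_hook(module, grad_input, grad_output):
    # Module backward hook: grad_output[0] is dL/d(output of cnn2)
    store["module_grad"] = grad_output[0]

h1 = model.cnn2.register_forward_hook(forward_hook)
h2 = model.cnn2.register_full_backward_hook(full_backward_hook)

img = torch.randn(1, 1, 28, 28)  # placeholder for a real MNIST test image
scores = model(img)
scores[0, scores.argmax(dim=1).item()].backward()

# Remove the hooks so they stop firing on later passes
h1.remove()
h2.remove()
print(store["activations"].shape, store["tensor_grad"].shape)
print(torch.allclose(store["tensor_grad"], store["module_grad"]))
```

The forward hook plus Tensor hook is the more direct route for Grad-CAM, since it gives both A and dL/dA for the layer without editing the model class.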