CNN visualisation hooks

Hey, in the CNN visualisation toolkit there is a snippet I am unable to understand:

def hook_layer(self):
    def hook_function(module, grad_in, grad_out):
        # Gets the conv output of the selected filter (from selected layer)
        self.conv_output = grad_out[0, self.selected_filter]

    # Hook the selected layer

What are the hook functions and ‘grad_out’ doing? grad_out is not defined anywhere, yet the code works fine. Also, it is later called like this:

def visualise_layer_with_hooks(self):
    # Hook the selected layer

without any arguments being passed for grad_in or grad_out. What is this line doing, and how does it work?

The naming is a bit misleading, since grad_in and grad_out are the names typically used in backward hooks.
In forward hooks the natural names would simply be input and output.
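To make the naming difference concrete, here is a minimal sketch comparing a forward hook with a backward hook on the same toy layer. The layer sizes and the `seen` dict are arbitrary illustrations, and it assumes a recent PyTorch version that provides `register_full_backward_hook`. In the forward hook the third argument is the layer's output tensor; only in the backward hook are the arguments actually gradients.

```python
import torch
import torch.nn as nn

# Toy layer (sizes chosen arbitrarily) used only to compare the two hook types
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
seen = {}

def forward_hook(module, inputs, output):
    # Forward hook: `output` is the layer's activation tensor, not a gradient
    seen["forward_output_shape"] = tuple(output.shape)

def backward_hook(module, grad_input, grad_output):
    # Backward hook: grad_input / grad_output are tuples of gradients
    seen["grad_output_shape"] = tuple(grad_output[0].shape)

fwd_handle = conv.register_forward_hook(forward_hook)
bwd_handle = conv.register_full_backward_hook(backward_hook)

x = torch.randn(2, 3, 8, 8, requires_grad=True)
conv(x).sum().backward()  # forward hook fires during the call, backward hook during .backward()

fwd_handle.remove()
bwd_handle.remove()
print(seen)
```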

You are basically creating a function named hook_function with a specific signature which is expected by register_forward_hook.

register_forward_hook makes sure the function you passed is called with three arguments: the nn.Module you registered it on, that module's input, and its output.
This happens automatically every time the module's forward pass runs, so you never see in your own code where input and output are created or passed in.
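A small sketch of that automatic call, using an arbitrary nn.ReLU layer and a hypothetical `calls` list: the hook is never called by hand, yet it receives the module, input and output on each forward pass until its handle is removed.

```python
import torch
import torch.nn as nn

layer = nn.ReLU()
calls = []

def hook_function(module, inputs, output):
    # PyTorch calls this after layer's forward has run;
    # we never pass module / inputs / output ourselves.
    calls.append((type(module).__name__, inputs[0].shape, output.shape))

handle = layer.register_forward_hook(hook_function)

layer(torch.randn(1, 5))  # first forward pass: hook fires
layer(torch.randn(3, 5))  # second forward pass: hook fires again

handle.remove()           # detach the hook
layer(torch.randn(2, 5))  # no hook call this time

print(len(calls))  # 2
```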

The last line registers hook_function as a forward hook on the currently selected layer (selected_layer).
For that to work, selected_layer has to be set beforehand or have a default value.
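As a sketch of registering a hook on a selected layer, assuming the model is an nn.Sequential (which supports integer indexing). The model, `selected_layer = 2` and the `captured` dict are hypothetical stand-ins, not the toolkit's own objects.

```python
import torch
import torch.nn as nn

# Hypothetical small CNN standing in for the model's feature layers
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)

selected_layer = 2  # assumed index; must be set before the hook is registered
captured = {}

def hook_function(module, inputs, output):
    captured["output"] = output

# Look up the selected layer by index and attach the forward hook to it
handle = model[selected_layer].register_forward_hook(hook_function)
model(torch.randn(1, 3, 32, 32))
handle.remove()

print(captured["output"].shape)  # torch.Size([1, 16, 32, 32])
```

Keeping the returned handle lets you call `handle.remove()` once the activation has been captured, so the hook doesn't keep firing on later passes.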

Is grad_out a predefined function in torch? I could not find it in the docs.

But self.conv_output is never handed back through register_forward_hook (hook_function does not return it).

No, they are just variable names. Despite the grad_ prefix they don't hold gradients here: they are the input and output of the layer's forward pass.

You could also name them a and b. What matters is the position: register_forward_hook expects a function that takes three arguments, namely the module itself, the input to that layer (a tuple of its positional inputs), and the layer's output.
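A quick sketch showing that only the argument positions matter. The hook below uses the deliberately meaningless names `m`, `a`, `b` on an arbitrary nn.Linear layer, and the `record` dict is just for inspection.

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 2)
record = {}

def whatever(m, a, b):
    # Names are arbitrary; only the position matters:
    # m = the module, a = tuple of positional inputs, b = the output tensor
    record["module_is_lin"] = m is lin
    record["a_is_tuple"] = isinstance(a, tuple)
    record["b_shape"] = tuple(b.shape)

h = lin.register_forward_hook(whatever)
lin(torch.randn(5, 4))
h.remove()

print(record)  # {'module_is_lin': True, 'a_is_tuple': True, 'b_shape': (5, 2)}
```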

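self.conv_output just stores the activation of the first sample in the batch for the channel self.selected_filter (for a 2D conv layer, a single H x W activation map).
Nothing needs to be returned: hook_function is a closure over self, so the assignment saves the activation as an attribute of your class. (A forward hook that returns a non-None value would actually replace the layer's output.) After a forward pass you can read my_class.conv_output to visualize the activation map.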
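Putting it together, here is a hypothetical minimal class that mirrors the pattern discussed above. It is not the toolkit's actual implementation: the class name `LayerActivation`, the assumption that `model` is an nn.Sequential indexable by `selected_layer`, and returning the hook handle are all illustrative choices.

```python
import torch
import torch.nn as nn

class LayerActivation:
    """Hypothetical minimal class mirroring the hook pattern (not the toolkit's code)."""

    def __init__(self, model, selected_layer, selected_filter):
        self.model = model
        self.selected_layer = selected_layer
        self.selected_filter = selected_filter
        self.conv_output = None

    def hook_layer(self):
        def hook_function(module, inputs, output):
            # `self` is captured by the closure, so the activation is stored on
            # the instance; the hook returns None and leaves the output unchanged.
            self.conv_output = output[0, self.selected_filter]
        # Assumes self.model supports integer indexing (e.g. nn.Sequential)
        return self.model[self.selected_layer].register_forward_hook(hook_function)

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
viz = LayerActivation(model, selected_layer=0, selected_filter=5)
handle = viz.hook_layer()
with torch.no_grad():
    model(torch.randn(2, 3, 16, 16))
handle.remove()

print(viz.conv_output.shape)  # torch.Size([16, 16])
```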