Correctly indexing the input tuple in a backward hook function


I’ve noticed something peculiar when using PyTorch backward hook functions. Perhaps this was designed this way on purpose, but it makes using backward hooks a bit unintuitive.

First off, a backward hook function would look something like this:

def hook_fn(module, inputs, outputs):
    pass  # inspect or modify gradients here

where the function’s parameters are the module itself, the input gradients (a tuple), and the output gradients (a tuple).

This is the part that confuses me. The parameter inputs is a tuple containing gradients of the module, but the order of its elements seems a bit… random. Not stochastic, but random. To elaborate, if we register this function on an nn.Linear layer with bias parameters, this tuple becomes:

(grad_bias, grad_preactivation, grad_weight) 

So why is grad_preactivation (the gradient with respect to the layer’s input, i.e. the previous layer’s activation) between grad_bias and grad_weight?

What makes it worse is that for convolutional layers, this order is shuffled. For nn.Conv2d layers, the order seems to be:

(grad_preactivation, grad_weight, grad_bias)

Is there any documentation that explains this in depth (I can’t seem to find much :cry:), and are there any better ways to intercept gradients in the backward pass? For example, say I want to modify the gradients of the weights in each layer, where some layers are nn.Conv2d and others are nn.Linear, and bias could be either True or False. Would I need to hard-code my hook function to account for all possibilities, or is there a more elegant solution?

Thanks in advance :slight_smile:


Are you using Module.register_backward_hook()?
If so, this is a known issue: if you check the latest docs, it has been deprecated in favor of Module.register_full_backward_hook(), which has the exact same API and docs (but actually does what the docs say it’s supposed to do :smiley:).
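
For reference, a minimal sketch of the full backward hook (the model and shapes here are just an illustration): with register_full_backward_hook, grad_input and grad_output are consistently the gradients with respect to the module’s forward inputs and outputs, for both nn.Linear and nn.Conv2d.

```python
import torch
import torch.nn as nn

records = []

def hook_fn(module, grad_input, grad_output):
    # grad_input:  gradients w.r.t. the module's forward inputs (tuple)
    # grad_output: gradients w.r.t. the module's forward outputs (tuple)
    records.append((type(module).__name__,
                    tuple(None if g is None else tuple(g.shape) for g in grad_input),
                    tuple(tuple(g.shape) for g in grad_output)))

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
for layer in model:
    layer.register_full_backward_hook(hook_fn)

model(torch.randn(2, 3, 32, 32)).sum().backward()
for name, g_in, g_out in records:
    print(name, g_in, g_out)
```

Note that the hooks fire in reverse layer order during the backward pass, so the Linear layer is recorded first.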

Thanks for the reply.
I just tried out register_full_backward_hook(), and it looks like it only exposes the gradients with respect to the module’s inputs and outputs (no parameter gradients). Thanks for your answer :grinning_face_with_smiling_eyes:

On a side note:
I assume, for the gradients of model parameters, there is no need to access them mid-backprop? I guess they could be accessed by iterating over model.named_parameters() after the full backward pass.
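
Something like this is what I had in mind (the model here is just a toy example):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model(torch.randn(5, 4)).sum().backward()

# After backward(), each parameter's .grad field is populated
for name, param in model.named_parameters():
    print(name, None if param.grad is None else tuple(param.grad.shape))
```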

Yes, you can access them afterwards.
If you want to access them in a hook, you can register a Tensor hook directly on those parameters: mod.weight.register_hook(hook_fn).
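
A quick sketch of that approach (the doubling here is just an example modification; returning a tensor from the hook replaces the gradient):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
# A tensor hook's return value replaces the gradient; here we double it
layer.weight.register_hook(lambda grad: grad * 2.0)

x = torch.randn(4, 3)
layer(x).sum().backward()
# For out = layer(x).sum(), the un-hooked weight gradient is x.sum(0) in every row
print(torch.allclose(layer.weight.grad, 2.0 * x.sum(0).expand(2, 3)))  # True
```

This works uniformly for any parameter tensor (Conv2d or Linear, with or without bias), which sidesteps the tuple-ordering issue entirely.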


Ok. Thanks for the thorough reply :smiley: