Ordering of elements in a PyTorch module's backward hook?


I’ve been trying to use the hook functionality in PyTorch and have run into a few questions about it.

To begin with, here are the signature and docstring of the register_backward_hook method:

def register_backward_hook(self, hook): 
    The hook will be called every time the gradients with respect to module
    inputs are computed. The hook should have the following signature::

       hook(module, grad_input, grad_output) -> Tensor or None

    The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
    module has multiple inputs or outputs. The hook should not modify its
    arguments, but it can optionally return a new gradient with respect to
    input that will be used in place of :attr:`grad_input` in subsequent
    computations.

    Returns:
        a handle that can be used to remove the added hook by calling

    .. warning ::

        The current implementation will not have the presented behavior
        for complex :class:`Module` that perform many operations.
        In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
        contain the gradients for a subset of the inputs and outputs.
        For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
        directly on a specific input or output to get the required gradients.
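The warning recommends calling torch.Tensor.register_hook directly on a specific input or output. A minimal sketch of that approach, assuming a small nn.Linear layer with arbitrary sizes as a stand-in for a real network (the save helper is hypothetical): the hook on x receives the gradient with respect to the layer's input, the hook on y the gradient with respect to its output, and parameter gradients still land in layer.weight.grad and layer.bias.grad as usual.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)   # stand-in layer; sizes are arbitrary
captured = {}

def save(name):
    # Hypothetical helper: build a Tensor hook that stores a copy of the gradient.
    # Returning None from the hook leaves the gradient unchanged.
    def hook(grad):
        captured[name] = grad.detach().clone()
    return hook

x = torch.randn(2, 4, requires_grad=True)  # the input must require grad to receive one
x.register_hook(save("grad_input"))

y = layer(x)
y.register_hook(save("grad_output"))

y.sum().backward()

print(captured["grad_input"].shape)   # torch.Size([2, 4])
print(captured["grad_output"].shape)  # torch.Size([2, 3])
```

Because each hook is tied to one tensor, there is no tuple whose ordering has to be guessed.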


As stated, the expected signature for a module backward hook is hook(module, grad_input, grad_output), and grad_input and grad_output may be tuples whenever the module has multiple inputs or outputs.
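A minimal sketch of a hook with that signature, registered on the first layer of a small hypothetical model (the model and the log_hook name are assumptions for illustration). It only records the type and length of each argument, since the docstring says the hook should not modify them, and it removes itself through the handle that register_backward_hook returns. Recent PyTorch versions may print a deprecation warning for this legacy hook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))  # hypothetical model
records = []

def log_hook(module, grad_input, grad_output):
    # Record metadata only; returning None keeps grad_input unchanged.
    records.append({
        "module": type(module).__name__,
        "grad_input_type": type(grad_input).__name__,
        "n_grad_input": len(grad_input),
        "n_grad_output": len(grad_output),
    })
    return None

handle = model[0].register_backward_hook(log_hook)
model(torch.randn(5, 4, requires_grad=True)).sum().backward()
handle.remove()  # detach the hook once it is no longer needed

print(records)
```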

My question is this: what is the structure of the grad_input tuple? For a basic MNIST classification net that I’ve set up, it appears to be (grad_input_to_layer, grad_weights, grad_bias), where grad_input_to_layer is the gradient with respect to the layer’s input, grad_weights is the gradient with respect to the layer’s weights, and grad_bias is the gradient with respect to the layer’s bias.
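One way to check which slot is which, without relying on any ordering guarantee, is to compare each entry's shape against the layer's input and parameters. The describe_grad_input helper below is hypothetical and purely a heuristic: equal shapes do not prove identity (a square weight has the same shape as its transpose), and a slot can match nothing, for example when the last operation received a transposed weight. The Conv2d sizes are assumptions chosen to resemble a first MNIST layer.

```python
import torch
import torch.nn as nn

def describe_grad_input(module, x, grad_input):
    # Hypothetical helper: label each grad_input slot by comparing its shape
    # with the module's input and parameters. Heuristic only: equal shapes do
    # not prove identity, and a transposed weight gradient matches nothing.
    candidates = {"input": x.shape}
    candidates.update({name: p.shape for name, p in module.named_parameters()})
    labels = []
    for g in grad_input:
        if g is None:
            labels.append("None")
            continue
        matches = [name for name, shape in candidates.items() if shape == g.shape]
        labels.append("/".join(matches) if matches else f"unknown{tuple(g.shape)}")
    return labels

torch.manual_seed(0)
conv = nn.Conv2d(1, 4, kernel_size=3)              # MNIST-like layer: 1 input channel
x = torch.randn(2, 1, 28, 28, requires_grad=True)  # so the input slot is not None
seen = []

def log_hook(module, grad_input, grad_output):
    # Legacy module hook: only inspect, never modify, the arguments.
    seen.append(describe_grad_input(module, x, grad_input))

handle = conv.register_backward_hook(log_hook)
conv(x).sum().backward()
handle.remove()

print(seen[0])  # one label per slot, in whatever order this PyTorch version uses
```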

Is there a guarantee about the structure of the grad_input and grad_output tuples?

The module backward hook isn’t really what you expect.

Best regards
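The legacy module hook is attached to the grad_fn of the module's output tensor, so grad_input has one slot per input of that single autograd node, regardless of what the module computed before it. A sketch for looking at that node, using a hypothetical output_node_inputs helper and an nn.Linear with arbitrary sizes; the exact node names differ between PyTorch versions, so treat the printed names as illustrative.

```python
import torch
import torch.nn as nn

def output_node_inputs(module, x):
    # Hypothetical helper: return the name of the autograd node that produced
    # module(x) and the names of the nodes feeding into it. The legacy hook's
    # grad_input has one entry per element of next_functions.
    out = module(x)
    node = out.grad_fn
    feeding = [type(fn).__name__ if fn is not None else None
               for fn, _ in node.next_functions]
    return type(node).__name__, feeding

torch.manual_seed(0)
lin = nn.Linear(4, 3)
# 2-D input: typically a single matrix-multiply-with-bias node.
print(output_node_inputs(lin, torch.randn(5, 4, requires_grad=True)))
# 3-D input: the last node is typically not the matrix multiply itself
# (e.g. a view or an add, depending on the PyTorch version).
print(output_node_inputs(lin, torch.randn(2, 5, 4, requires_grad=True)))
```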


Thanks Tom. Do you have a recommendation for how to proceed with our use case, then? I would simply like to figure out what is being logged and in what order, but the link seems to suggest the hook is “broken”?

Yeah, well, the hook is broken.
In particular, this means that the structure just comes from the last operation in the module’s forward.
For some modules (conv, or linear without broadcasting) that is OK; for others, not so much.

Best regards
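PyTorch 1.8 and later provide nn.Module.register_full_backward_hook, and the documentation marks register_backward_hook as deprecated in its favor. With the full hook, grad_input holds gradients with respect to the module's positional tensor inputs only (parameter gradients are not included), so its structure no longer depends on the last operation in forward. A minimal sketch, assuming an nn.Linear with arbitrary sizes and a 3-D input, the broadcasting case where the legacy hook misbehaves; full_hook is a hypothetical name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 3)
captured = {}

def full_hook(module, grad_input, grad_output):
    # grad_input: one entry per positional forward input (None if that input
    # does not require grad). grad_output: one entry per output.
    # Parameter gradients are read from module.weight.grad / module.bias.grad.
    captured["grad_input"] = [None if g is None else tuple(g.shape) for g in grad_input]
    captured["grad_output"] = [tuple(g.shape) for g in grad_output]

handle = lin.register_full_backward_hook(full_hook)
x = torch.randn(2, 5, 4, requires_grad=True)  # 3-D input goes through broadcasting
lin(x).sum().backward()
handle.remove()

print(captured["grad_input"])   # [(2, 5, 4)]
print(captured["grad_output"])  # [(2, 5, 3)]
print(lin.weight.grad.shape, lin.bias.grad.shape)
```

Keeping parameter gradients out of the hook is what makes the tuple layout predictable: it mirrors the forward signature rather than an internal autograd node.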