Register full backward hook vs. register module full backward hook

I am trying to understand the differences between the functions in the title. What are the appropriate use cases for both? I especially would like to see an example where register_module_full_backward_hook is used properly. Thank you!

(There is actually a bug that needs to be fixed if you use a nightly or build from source!)
Edit: I wasn't aware that the bug was very recently reverted, so only certain nightlies would be broken.

Otherwise, this is how it should be used:

import torch

a = torch.rand(10, requires_grad=True)
lin = torch.nn.Linear(10, 10)

def hook(mod, grad_inputs, grad_outputs):
    # grad_outputs is the gradient wrt the module's outputs, i.e. what the
    # module's backward receives
    # grad_inputs is the gradient wrt the module's inputs (in this case the
    # input is `a`)
    assert mod is lin
    return (grad_inputs[0] * 100,)

torch.nn.modules.module.register_module_full_backward_hook(hook)

out = lin(a)
out.sum().backward()
print(a.grad)
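
For comparison, here is a minimal sketch of the per-module variant, `Module.register_full_backward_hook`, which attaches the hook to one module instance instead of registering it globally. The names `scale_hook`, `hooked_grad` and `plain_grad` are just illustrative, and the second backward pass is only there to show what the gradient looks like once the hook is removed.

```python
import torch

torch.manual_seed(0)
a = torch.rand(10, requires_grad=True)
lin = torch.nn.Linear(10, 10)

def scale_hook(mod, grad_inputs, grad_outputs):
    # Registered on `lin` only, so `mod` is always `lin` here.
    # The returned tuple replaces the gradient wrt the module's inputs.
    return (grad_inputs[0] * 100,)

# Per-module registration: only `lin` is affected.
handle = lin.register_full_backward_hook(scale_hook)

lin(a).sum().backward()
hooked_grad = a.grad.clone()

# Remove the hook and recompute the gradient without it.
handle.remove()
a.grad = None
lin(a).sum().backward()
plain_grad = a.grad.clone()

print(hooked_grad)
print(plain_grad)
```

With the hook attached, the gradient reaching `a` should be the unhooked gradient scaled by 100. Both registration functions return a handle, and calling `handle.remove()` detaches the hook again.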

Thank you! Follow-up question: how does register_module_full_backward_hook(hook) know which module it should be called on? Or will it be called whenever .backward() is called (with the module as an argument), no matter which module is involved?

Yeah, it is a global hook: it applies to all modules and gets called during .backward() whenever the gradients for a module are computed. To figure out which module it's being called for, you'd look at the first argument, yes.
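
To make that concrete, here is a small sketch (not from the original post) that registers a global hook, runs backward through a tiny model, and records which modules the hook was called for. The model, the `seen` list and the `global_hook` name are made up for illustration. As far as I know the docs describe the global variant as intended for debugging/profiling, since it adds global state.

```python
import torch
from torch.nn.modules.module import register_module_full_backward_hook

# Hypothetical toy model, only used to have more than one module.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1),
)

seen = []  # class names of the modules the hook was called for

def global_hook(mod, grad_inputs, grad_outputs):
    # The first argument tells you which module this call is for.
    seen.append(type(mod).__name__)
    # Returning None leaves the gradients unchanged.

handle = register_module_full_backward_hook(global_hook)

x = torch.rand(4, 10, requires_grad=True)
model(x).sum().backward()

# The hook is global state, so remove it once you are done.
handle.remove()

print(seen)
```

If you only care about some modules, you can check the first argument inside the hook (for example with `isinstance(mod, torch.nn.Linear)`) and return None for everything else.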
