Hook which can modify gradients during backward

Is there a way to achieve the following in pytorch:
As each gradient of a tensor with requires_grad=True is calculated, we apply some function to the gradient, and these changes are then propagated to the rest of the gradient calculations.

For example, suppose A and B are matrices, both with requires_grad=True, and my model is
output = B @ (A @ x)
and when I put
output.backward()
I want to be able to apply my function to the gradient of B before it is propagated backwards to the gradient of A.
I have tried B.register_hook() and this lets me adjust the gradient of B, but only after the gradient of A is calculated.
Thank you.

Hi @David_Tweedle,

Have a look at the different types of hooks here: Module — PyTorch 2.3 documentation


I'd actually expect applying a hook on B via .register_hook() to do what you want. Do you have a snippet to reproduce the issue?


Hi there,

@AlphaBetaGamma96 Yes, this is what I was looking for, thank you very much!

@soulitzer I went back and tried the following code, and you are right, torch.Tensor.register_hook does what I want too. Here is a snippet to verify.

import torch
A = torch.tensor([[1.,0],[1.,2.]], requires_grad=True)
B = torch.tensor([1.,2.], requires_grad=True)
x = torch.tensor([1.,2.])
def output(input):
    return B @ (A @ input)
A.grad, B.grad = None, None
loss = output(x)
loss.register_hook(lambda grad: 0 * grad) # gradients of A and B should be zero
loss.backward()
A.grad, B.grad
# (tensor([[0., 0.],
#         [0., 0.]]),
# tensor([0., 0.]))
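A hedged variant of the snippet above: hooking the loss zeros every gradient, but registering the hook on the intermediate tensor A @ x instead (a plain-tensor sketch, not from the thread) modifies only the gradient that flows back into A, leaving B's gradient untouched.

```python
import torch

A = torch.tensor([[1., 0.], [1., 2.]], requires_grad=True)
B = torch.tensor([1., 2.], requires_grad=True)
x = torch.tensor([1., 2.])

h = A @ x                               # intermediate, non-leaf tensor
h.register_hook(lambda grad: 0 * grad)  # runs before A.grad is computed
loss = B @ h
loss.backward()

print(A.grad)  # zeros: the modified gradient propagated into A
print(B.grad)  # unaffected, since B's gradient comes from the outer matmul
```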

Thank you both for your replies.
