Is there a way to achieve the following in PyTorch:
As the gradient of each tensor with requires_grad=True is calculated, we apply some function to that gradient, and these changes are then propagated to the rest of the gradient calculations.
For example, if A and B are matrices, both with requires_grad=True, and my model is as follows
output = B @ (A @ x)
and when I call
output.backward()
I want to be able to apply my function to the gradient of B before it is propagated backwards to the gradient of A.
I have tried B.register_hook() and this lets me adjust the gradient of B, but only after the gradient of A is calculated.
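Here is roughly what I tried (a sketch; my_function below is just a stand-in for the adjustment I want to make):

```python
import torch

A = torch.tensor([[1., 0.], [1., 2.]], requires_grad=True)
B = torch.tensor([1., 2.], requires_grad=True)
x = torch.tensor([1., 2.])

def my_function(grad):
    # stand-in for the adjustment I want to apply to B's gradient
    return 2 * grad

# fires when the gradient with respect to B is computed
B.register_hook(my_function)

output = B @ (A @ x)
output.backward()
# B.grad is adjusted by the hook, but A.grad is unaffected, since A's
# gradient is computed from the intermediate A @ x rather than from B.grad
```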
Thank you.
Hi @David_Tweedle,
Have a look at the different types of hooks here: Module - PyTorch 2.3 documentation
I'd actually expect applying a hook on B via .register_hook() to do what you want. Do you have a snippet to reproduce the issue?
Hi there,
@AlphaBetaGamma96 Yes, this is what I was looking for, thank you very much!
@soulitzer I went back and tried the following code, and you are right, torch.Tensor.register_hook does what I want too. Here is a snippet to verify.
import torch
A = torch.tensor([[1.,0],[1.,2.]], requires_grad=True)
B = torch.tensor([1.,2.], requires_grad=True)
x = torch.tensor([1.,2.])
def output(input):
    return B @ (A @ input)
A.grad, B.grad = None, None
loss = output(x)
loss.register_hook(lambda grad: 0 * grad) # gradients of A and B should be zero
loss.backward()
A.grad, B.grad
# (tensor([[0., 0.],
# [0., 0.]]),
# tensor([0., 0.]))
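And to change the gradient before it reaches A specifically, a hook on the intermediate tensor A @ x works too (a sketch, with 10 * grad as a stand-in for the function to apply):

```python
import torch

A = torch.tensor([[1., 0.], [1., 2.]], requires_grad=True)
B = torch.tensor([1., 2.], requires_grad=True)
x = torch.tensor([1., 2.])

h = A @ x   # intermediate tensor
out = B @ h

# the hook fires when the gradient with respect to h is computed,
# i.e. before it flows back to A
h.register_hook(lambda grad: 10 * grad)
out.backward()
# A.grad is scaled by 10, while B.grad (computed directly from h) is not
```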
Thank you both for your replies.