I’m working on a problem where I have to calculate gradients with respect to intermediate tensors, use those gradients in further calculations to get a final value, and then backpropagate again from that final value. I know it sounds confusing, so I made a dummy example of what I’m doing:
```python
import torch

# initialize tensor a and do some dummy operations
a = torch.tensor(2.4, requires_grad=True)
b = a * 7
c = 4 * b ** 2
c.retain_grad()  # I want to use a gradient with respect to this variable

# more dummy operations
d = 2 * torch.log(c) * a
e = d ** 1.2

# backpropagating to get the intermediate grad
e.backward(retain_graph=True)
f = c.grad  # I got the gradient I wanted: f = de/dc
print(f)

g = e * d * f * f  # using f in a further dummy calculation

# finally backpropagating to get dg/da
g.backward()
print(a.grad)
```
The thing is: when I call e.backward(retain_graph=True), gradients are computed all the way back to the tensor a. In this dummy example that’s no big deal, but in my original problem it wastes a lot of time on unnecessary computation. Is there any way to stop backward() as soon as I have a gradient on the tensor c?
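To make the question concrete, this is the kind of replacement I’m imagining for the e.backward(retain_graph=True) call. I know torch.autograd.grad can target specific tensors, but I’m not sure whether it actually skips the extra computation from c back to a, which is what I’m after:

```python
import torch

a = torch.tensor(2.4, requires_grad=True)
b = a * 7
c = 4 * b ** 2
d = 2 * torch.log(c) * a
e = d ** 1.2

# Ask autograd only for de/dc instead of calling e.backward();
# the hope is that the backward pass stops at c and never walks
# the c -> b -> a part of the graph.
(f,) = torch.autograd.grad(e, c, retain_graph=True)
print(f)  # f = de/dc

g = e * d * f * f  # same dummy calculation as above
g.backward()       # dg/da, as before
print(a.grad)
```

Note that with torch.autograd.grad I no longer need c.retain_grad(), since the gradient is returned directly instead of being stored in c.grad.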
Thank you for reading all this and I hope someone can help me.