Stop backward() at some intermediate tensor

Hi everyone,

I’m working on a problem where I have to calculate gradients with respect to intermediate tensors, use those gradients in further calculations to get a final value, and then backpropagate again from that final value. I know it sounds confusing, so I made a dummy example of what I’m doing:

import torch

# initialize tensor a and do some dummy operations
a = torch.tensor(2.4, requires_grad=True)
b = a * 7
c = 4 * b ** 2
c.retain_grad()  # I want to use a gradient with respect to this variable

# more dummy operations
d = 2 * torch.log(c) * a
e = d ** 1.2

# backpropagating to get the intermediate grad
e.backward(retain_graph=True)
f = c.grad  # I got the gradient I wanted: f = de/dc

g = e * d * f * f  # using the f in more dummy calculation

# Finally backpropagating to get dg/da
g.backward()

The thing is: when I call e.backward(retain_graph=True), gradients are computed all the way back to the tensor a. In this specific problem it’s no big deal, but in my original problem it takes too much time doing unnecessary computation. Is there any way to stop backward() as soon as I have the gradient on the tensor c?

Thank you for reading all this and I hope someone can help me.

I did some research and it seems that

f, = torch.autograd.grad(e, c, retain_graph=True)

does the job! Sorry, I’m new to PyTorch.


Note for anybody in the future:

e.backward(inputs=(c,), retain_graph=True)

works for cases where c is a leaf tensor/Parameter, and populates c.grad with the gradient

This (should) let you do things like:

loss.backward(inputs=tuple(model.parameters()), retain_graph=True)

… which lets you backprop through a model or two using retain_graph without autograd doing any unnecessary compute. It looks like autograd automatically figures out which intermediates it needs to backprop through and skips the rest.
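To illustrate, here is a small sketch with two hypothetical modules (`encoder` and `head` are made-up names, not from the thread): only the parameters passed via inputs= receive gradients, and the upstream model is skipped entirely.

```python
import torch

torch.manual_seed(0)
encoder = torch.nn.Linear(4, 4)  # hypothetical upstream model
head = torch.nn.Linear(4, 1)     # the model we actually want grads for

x = torch.randn(2, 4)
loss = head(encoder(x)).sum()

# Only head's parameters are listed as inputs, so autograd never
# backprops past `head`; encoder's parameters keep grad=None.
loss.backward(inputs=tuple(head.parameters()), retain_graph=True)

print(all(p.grad is None for p in encoder.parameters()))      # True
print(all(p.grad is not None for p in head.parameters()))     # True
```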