backward() through a graph iteratively and efficiently

I want to iteratively roll out a computation graph while checking the gradients at each step efficiently, ideally within the same time needed to compute the gradients from the full graph.

This is probably best understood via a minimal example. Here is a standard graph

from torch import tensor, zeros

x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
y[-1].backward()

which produces the gradient tensor([1.2677e+30]) in 0.0022s on my machine (on CPU).

Now, what I want to do is compute the gradients at each step of the rollout without any hit to performance. Here is a minimal example that at least gets the correct per-step gradients, using retain_graph

from torch import tensor, zeros

x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
y[0].backward(retain_graph=True)
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
    x.grad.zero_()  # introduces extra compute but makes the gradients correct
    y[i].backward(retain_graph=True)

This produces the correct gradient tensor([1.2677e+30]); however, it does so in 0.0618s. My goal is to bring this computation time down as much as possible.

I tried calling retain_grad() on all intermediate variables, which indeed retains the gradients at each intermediate variable. However, it looks like backward() doesn't reuse the gradients it has already computed and instead does the whole backward pass from scratch. Assuming that is correct, my question boils down to: how can I make autograd reuse the gradients already available in the graph instead of computing them again?
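Roughly, this is what I mean by the gradients being retained (a sketch, the same loop as in the script below with retain_grad() enabled):

from torch import tensor, zeros

x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
y[0].backward(retain_graph=True)
for i in range(1, len(y)):
    y[i - 1].retain_grad()  # keep .grad on the non-leaf intermediate
    y[i] += y[i - 1] ** 2
    x.grad.zero_()
    y[i].backward(retain_graph=True)
    # y[i - 1].grad is populated at this point, but the backward() call above
    # still traversed the whole graph down to x rather than starting from it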

Complete script I used, for ease of reproduction:

from torch import tensor, zeros
from time import time

start = time()
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
y[-1].backward()
end = time() - start
print("real y grads in {:.4f}s".format(end), x.grad)

# now get iterative grads
start = time()
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
y[0].backward(retain_graph=True)
for i in range(1, len(y)):
    # y[i - 1].retain_grad()  # uncomment to keep the gradients on the intermediates
    y[i] += y[i - 1] ** 2
    x.grad.zero_()  # introduces extra compute but makes gradients correct
    y[i].backward(retain_graph=True)
end = time() - start
print("y grads in {:.4f}s".format(end), x.grad)

A dumb way to do this is to manually chain the gradients together: if y_2 = fn(y_1), you could do dy_1 = torch.autograd.grad(y_2, y_1, grad_outputs=(dy_2,)).
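For the exact setup in the question, a sketch of that chaining could look like the following. It only ever backprops through the most recent op and keeps an accumulated gradient (running_grad is just a name I picked for d y[i] / d x, not an autograd feature), and it relies on everything here being a 1-element tensor so the chain rule collapses to a plain multiplication:

import torch
from torch import tensor
from time import time

start = time()
x = tensor([1.0], requires_grad=True)
y_prev = x ** 2  # this is y[0]
# d y[0] / d x through the one-op graph
running_grad = torch.autograd.grad(y_prev, x, grad_outputs=torch.ones_like(y_prev))[0]
y_prev = y_prev.detach()  # cut the graph so each step is its own tiny graph

for i in range(1, 100):
    y_prev.requires_grad_(True)
    y_i = y_prev ** 2
    # local derivative d y[i] / d y[i - 1], a one-op backward pass
    local = torch.autograd.grad(y_i, y_prev, grad_outputs=torch.ones_like(y_i))[0]
    # chain rule: d y[i] / d x = d y[i] / d y[i - 1] * d y[i - 1] / d x
    # (a plain product only because these are 1-element tensors)
    running_grad = local * running_grad
    y_prev = y_i.detach()

end = time() - start
print("chained y grads in {:.4f}s".format(end), running_grad)

For this example it should end up at the same tensor([1.2677e+30]) as the other two versions. In the general (non-scalar) case the last multiplication would need the full Jacobian of y[i - 1] with respect to x rather than a single number, so the trick does not extend for free.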

Another way is to cache the outputs of the operations that happen during the backward pass. I don't think that is better, though, because (due to a bug) it probably breaks when you have in-place ops in backward formulas. It would also save a bunch of tensors unnecessarily.