I want to iteratively roll out a computation graph while checking the gradients at each step efficiently, ideally within the same time needed to compute the gradients from the full graph.

This is probably best understood via a minimal example. Here is a standard graph

```
from torch import tensor, zeros

x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
y[-1].backward()
```

which produces the gradient `tensor([1.2677e+30])` in 0.0022s on my machine (on CPU).

Now, what I want to do is compute the gradients at each step without any performance hit. Here is a minimal example that achieves this using `retain_graph`:

```
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
y[0].backward(retain_graph=True)
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
    x.grad.zero_()  # introduces extra compute but makes the gradients correct
    y[i].backward(retain_graph=True)
```

This produces the correct gradients `tensor([1.2677e+30])`, but it does so in 0.0618s. My goal is to bring this compute time down as much as possible.

I tried calling `retain_grad()` on all intermediate variables, which indeed retained the gradients at each intermediate variable. However, it looks like `backward()` isn't reusing the gradients it has already computed and instead runs the whole backward pass from scratch. Assuming that is correct, my question boils down to: how can I make autograd reuse the gradients already available in the graph instead of computing them again?
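For context, here is a sketch of the kind of reuse I have in mind, written by hand with `torch.autograd.grad`. It exploits the fact that my example is a simple chain, so the chain rule lets me multiply one-step gradients together; I'd prefer a solution that doesn't require hard-coding this structure:

```
import torch

# Hand-rolled chaining: each step computes only its own local derivative
# and multiplies it into the running gradient, so step i costs O(1)
# backward work instead of replaying the whole graph from scratch.
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2
grad = torch.autograd.grad(y, x)[0]  # d y[0] / d x, computed once
for _ in range(1, 100):
    y_prev = y.detach().requires_grad_(True)  # cut the graph at each step
    y = y_prev ** 2
    local = torch.autograd.grad(y, y_prev)[0]  # d y[i] / d y[i-1]
    grad = grad * local                        # chain rule
print(grad)  # tensor([1.2677e+30]), same as the full backward pass
```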

Complete script I used for ease of reproducing:

```
from torch import tensor, zeros
from time import time

start = time()
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
y[-1].backward()
end = time() - start
print("real y grads in {:.4f}s".format(end), x.grad)

# now get iterative grads
start = time()
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
y[0].backward(retain_graph=True)
for i in range(1, len(y)):
    # y[i - 1].retain_grad()  # makes this retain the gradients
    y[i] += y[i - 1] ** 2
    x.grad.zero_()  # introduces extra compute but makes the gradients correct
    y[i].backward(retain_graph=True)
end = time() - start
print("y grads in {:.4f}s".format(end), x.grad)
```