I want to iteratively roll out a computation graph while checking the gradients at each step efficiently, ideally in roughly the same time it takes to compute the gradients from the full graph once.
This is probably best understood via a minimal example. Here is a standard graph:
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
y[-1].backward()
which produces the gradient tensor([1.2677e+30]) in 0.0022s on my machine (CPU).
Now, what I want is to compute the gradient at every intermediate step without any hit to performance. Here is a minimal example that computes those per-step gradients correctly using retain_graph:
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
y[0].backward(retain_graph=True)
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
    x.grad.zero_()  # clear the accumulated gradient; extra compute, but keeps each step's gradient correct
    y[i].backward(retain_graph=True)
This produces the correct gradient tensor([1.2677e+30]); however, it does so in 0.0618s, roughly 28x slower than the single full backward pass. My goal is to bring this computing time down as much as possible.
I tried calling retain_grad() on all intermediate variables, which indeed retains the gradient at each intermediate variable. However, it looks like backward() does not reuse the gradients it has already computed and instead runs the whole backward pass from scratch. Assuming that is correct, my question boils down to: how can I make autograd reuse the gradients already available in the graph instead of computing them again?
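For concreteness, here is the kind of reuse I have in mind, hand-rolled for this specific recurrence (a sketch only, not a general mechanism): since y[i] = y[i - 1] ** 2, the chain rule gives d y[i]/dx = 2 * y[i - 1] * d y[i - 1]/dx, so each step's gradient is one multiplication away from the previous one.
from torch import tensor, zeros
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
grad = 2 * x.detach()  # d y[0]/dx = 2x
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
    grad = 2 * y[i - 1].detach() * grad  # reuse the previous step's gradient via the chain rule
# grad now equals tensor([1.2677e+30]) without a single backward() call
What I am asking is whether autograd can do something equivalent on an arbitrary graph, using the gradients it has already stored.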
Here is the complete script I used, for ease of reproducing:
from torch import tensor, zeros
from time import time
start = time()
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
for i in range(1, len(y)):
    y[i] += y[i - 1] ** 2
y[-1].backward()
end = time() - start
print("full-graph grads in {:.4f}s".format(end), x.grad)
# now get iterative grads
start = time()
x = tensor([1.0], requires_grad=True)
y = [zeros(1) for _ in range(100)]
y[0] += x**2
y[0].backward(retain_graph=True)
for i in range(1, len(y)):
    # y[i - 1].retain_grad()  # uncomment to retain the gradient on each intermediate variable
    y[i] += y[i - 1] ** 2
    x.grad.zero_()  # introduces extra compute but makes each step's gradient correct
    y[i].backward(retain_graph=True)
end = time() - start
print("iterative grads in {:.4f}s".format(end), x.grad)
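On my machine (CPU) this prints, up to timing noise:
full-graph grads in 0.0022s tensor([1.2677e+30])
iterative grads in 0.0618s tensor([1.2677e+30])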