# What's the best way to call backward() several times?

I’m writing code (specifically an implementation of Hamiltonian Monte Carlo) that calls `backward()` on a function several times at each iteration. Profiling this code shows that almost the entire run time is spent in these repeated calls to `backward()`.

As this function never changes, is there a way to store the gradient itself as a function to avoid the calls to `backward()`, or is there a way to speed up these repeated calls?

A very minimal example of the code I’m writing is shown below:

```python
import torch
from torch.autograd import Variable

y = Variable(torch.Tensor([0.1, 0.1]), requires_grad=True)

def myenergy(q):
    # corresponds approximately to rho = 0.95
    sigmainv = Variable(torch.Tensor([[10.25, -9.74], [-9.74, 10.25]]))
    return 0.5 * q.matmul(sigmainv).dot(q)

def HMC_basic(pos, energy, T=10000, n_steps=10, stepsize=0.25):
    # not quite a correct implementation of HMC
    for t in range(T):
        vel = Variable(torch.randn(pos.size()))

        for i in range(n_steps):
            if i != n_steps - 1:
                energy(pos).backward()
                # (gradients accumulate in pos.grad here; a real HMC
                # leapfrog step would use and then clear them)

    return pos

out = HMC_basic(y, myenergy)

#print(prof.key_averages())
```

Running this code on a CPU and profiling with cProfile and snakeviz shows that calls to `backward()` account for ~60% (15 s) of the total run time (~20 s).
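The commented-out `prof.key_averages()` line in the snippet above hints at PyTorch's built-in autograd profiler, which breaks time down per autograd op rather than per Python function. A minimal sketch of using it on a toy repeated-backward loop (the tensor and loop count here are just illustrative):

```python
import torch

x = torch.tensor([0.1, 0.1], requires_grad=True)

# Profile repeated backward() calls with torch.autograd.profiler,
# an alternative view to cProfile that attributes time to autograd ops.
with torch.autograd.profiler.profile() as prof:
    for _ in range(100):
        x.grad = None              # clear the accumulated gradient
        (0.5 * x.dot(x)).backward()

print(prof.key_averages().table(sort_by="cpu_time_total"))
```

This makes it easier to see whether the time is going into graph construction, the backward kernels themselves, or bookkeeping.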

In this example, calculating the gradient explicitly is simple, but in most cases it’s not.
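For the quadratic energy above it really is simple: the gradient of 0.5 · qᵀΣ⁻¹q with symmetric Σ⁻¹ is just Σ⁻¹q. A quick check (using the same matrix as in the question) that autograd agrees with the analytic formula:

```python
import torch

sigmainv = torch.tensor([[10.25, -9.74], [-9.74, 10.25]])
q = torch.tensor([0.1, 0.1], requires_grad=True)

(0.5 * q.matmul(sigmainv).dot(q)).backward()

# For symmetric S, d/dq [0.5 * q^T S q] = S q
analytic = sigmainv.matmul(q.detach())
print(torch.allclose(q.grad, analytic))  # True
```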

If you know the gradient doesn’t change, you could compute it once, store it in a variable, and reuse it elsewhere:

``````grad = energy(pos).backward()
``````
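A related option is `torch.autograd.grad`, which returns the gradient directly instead of accumulating it into `.grad`, so there is no zero-the-grad bookkeeping between steps. A minimal sketch, reusing the energy from the question:

```python
import torch

sigmainv = torch.tensor([[10.25, -9.74], [-9.74, 10.25]])
pos = torch.tensor([0.1, 0.1], requires_grad=True)

def energy(q):
    return 0.5 * q.matmul(sigmainv).dot(q)

# Returns the gradient as a tensor rather than writing it into pos.grad.
(g,) = torch.autograd.grad(energy(pos), pos)
print(g)
```

Note that this still rebuilds and traverses the graph on every call, so it tidies the code more than it speeds it up.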

If you wanted to memoize the calls to `backward()`, one thing you could do is write a function that computes the gradient you want, something like:

```python
prev_computations = {}  # maps pos -> gradient

def energy_backwards(pos):
    # note: torch tensors hash by identity, so in practice you would
    # key on something hashable by value, e.g. tuple(pos.tolist())
    if pos in prev_computations:
        return prev_computations[pos]
    # Otherwise, compute the gradient analytically
    ...