I’m using PyTorch as an autodiff tool to compute the first and second derivatives of a cost function to be used in a (non-deep-learning) optimization tool (ipopt). The cost function depends on about 10 parameters. I am using the new torch.autograd.functional.jacobian and torch.autograd.functional.hessian added in PyTorch 1.5. The Jacobian and Hessian get called several times (about 100) with different input parameters until the function is minimized.

To be more specific, I am computing the Jacobians and Hessians as:
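
A minimal sketch of this kind of call, assuming a scalar cost of a 10-element parameter vector. The function `cost` below is a hypothetical stand-in, not the one from the original post:

```python
import torch
from torch.autograd.functional import jacobian, hessian

def cost(x):
    # Hypothetical scalar cost of a 1-D parameter vector (illustration only).
    return (x ** 2).sum() + torch.sin(x).prod()

# About 10 parameters, as in the question; float64 is a common choice for optimizers.
x = torch.randn(10, dtype=torch.float64)

grad = jacobian(cost, x)  # gradient of the scalar cost, shape (10,)
hess = hessian(cost, x)   # Hessian, shape (10, 10)
```

In an optimizer loop, both calls would be repeated for every new `x`, which rebuilds the autograd graph each time.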

Since the functions used to compute the derivatives are always the same, I was wondering if there is a way (or if it even makes sense) to save the computational graph, to avoid recomputing the same quantities many times and thus speed up the calculation.

The short answer is:
PyTorch is actually built in such a way that the overhead of graph creation is low enough (compared to all the computations you do in the forward pass) that you can do it at every iteration.

The long answer is:
This is true for neural networks and when each op is quite “large”, but unfortunately, if you have only very small ops, the overhead might become noticeable.

The usual answer is: “use the jit to get the last drop of perf for your production evaluation”. But since you actually want to use backward here, you might run into trouble.
You can try to jit your function f, and I would be curious to know if it leads to any improvement.
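
One way to try this is tracing, sketched below under assumptions: `cost` is a hypothetical stand-in for f, and tracing only suits functions without data-dependent control flow. Whether it gives any speedup depends on the function and has to be measured.

```python
import torch
from torch.autograd.functional import jacobian

def cost(x):
    # Hypothetical scalar cost (illustration only).
    return (x ** 2).sum() + torch.sin(x).sum()

example = torch.randn(10, dtype=torch.float64)

# Record the ops executed on an example input into a TorchScript function.
traced_cost = torch.jit.trace(cost, example)

x = torch.randn(10, dtype=torch.float64)
value_eager = cost(x)
value_traced = traced_cost(x)

# First-order derivatives can still be taken through the traced function.
grad_eager = jacobian(cost, x)
grad_traced = jacobian(traced_cost, x)
```

Timing `cost` against `traced_cost` over many calls (for example with `timeit`) would show whether the jit helps for a given function.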

I have tried to use jit to speed up the computation. It doesn’t work. I was able to jit the original function, but this does not lead to a significant speedup. I can’t jit the gradients and the Hessian:

In the same gist I also compare against JAX for the gradient and Hessian. With JAX it is possible to precompute the gradients using jit. It seems that PyTorch is more or less on par with JAX for the gradient computation, but not for the Hessian computation: Hessian is about 20x slower in PyTorch than in JAX.

It seems that PyTorch is more or less on par with JAX for the gradient computation

That’s good news: it means that your ops are big enough that the creation of the graph is negligible.

Hessian is about 20x slower in PyTorch than in JAX.

That might be due to other reasons, mainly the way the Hessian is computed: JAX can use forward-mode AD to speed this up, while PyTorch does not have forward-mode AD (yet).
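
To make the cost of a reverse-mode-only Hessian concrete: for n parameters it takes one backward pass for the gradient, then n more backward passes through that gradient, one per row. The sketch below is only an illustration with a hypothetical `cost` function, and is not necessarily how torch.autograd.functional.hessian is implemented internally.

```python
import torch

def cost(x):
    # Hypothetical scalar cost (illustration only).
    return (x ** 3).sum() + (x[:-1] * x[1:]).sum()

def hessian_double_backward(f, x):
    # Work on a fresh leaf tensor that tracks gradients.
    x = x.detach().requires_grad_(True)
    # First backward pass: the gradient, keeping its graph so that it
    # can be differentiated again.
    (g,) = torch.autograd.grad(f(x), x, create_graph=True)
    rows = []
    for i in range(x.numel()):
        # One extra backward pass per parameter gives one Hessian row.
        (row,) = torch.autograd.grad(g[i], x, retain_graph=True)
        rows.append(row)
    return torch.stack(rows)

x = torch.randn(10, dtype=torch.float64)
H = hessian_double_backward(cost, x)  # shape (10, 10)
```

With forward-mode AD, a Hessian can instead be built by pushing tangent vectors forward through the gradient computation, which is the option the answer refers to for JAX.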