Reusing Jacobian and Hessian computational graph

It seems that PyTorch is more or less at the same level as JAX for the gradient computation.

That’s good news; it means that your ops are big enough that the cost of creating the graph is negligible.
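
For reference, a minimal sketch of what a gradient computation like this looks like in PyTorch (the function and sizes below are made up for illustration, not the original benchmark):

```python
import torch

def f(x):
    return (x ** 2).sum()

x = torch.randn(10_000, requires_grad=True)
# Reverse-mode gradient; the autograd graph is rebuilt on every call,
# so its creation cost only matters when f itself is cheap.
grad, = torch.autograd.grad(f(x), x)
```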

The Hessian, however, is about 20x slower in PyTorch than in JAX.

That might be due to other reasons, mainly the way the Hessian is computed: JAX can use forward-mode AD to speed this up, while PyTorch does not have forward-mode AD (yet :wink:).
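
For the Hessian, JAX can compose forward over reverse mode, which is what `jax.hessian` does under the hood. A minimal sketch (illustrative function, not the original code):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.ones(100)
# Forward-over-reverse: jacfwd of the reverse-mode gradient gives the
# Hessian without a second reverse pass through the graph.
H = jax.jacfwd(jax.grad(f))(x)
```

On the PyTorch side, a double-backward (reverse-over-reverse) pass is needed instead, which would explain part of the gap.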
