Reusing Jacobian and Hessian computational graph

Hello,

I’m using PyTorch as an autodiff tool to compute the first and second derivatives of a cost function that is then used in a (non-deep-learning) optimization tool (ipopt). The cost function depends on about 10 parameters. I am using the new torch.autograd.functional.jacobian and torch.autograd.functional.hessian functions added in PyTorch 1.5. The Jacobian and Hessian get called several times (about 100) with different input parameters until the function is minimized.

To be more specific, I am computing the Jacobians and Hessians as:

def jac(f, x):
    x_tensor = torch.tensor(x, requires_grad=True, dtype=torch.float64)   # Convert input variable to torch tensor
    jac = torch.autograd.functional.jacobian(f, x_tensor)
    return jac.numpy()

def hess(f, x):
    x_tensor = torch.tensor(x, requires_grad=True, dtype=torch.float64)   # Convert input variable to torch tensor
    hess = torch.autograd.functional.hessian(f, x_tensor)
    return hess.numpy()
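
For context, here is a minimal sketch of how wrappers like these might be plugged into a generic optimizer. It uses scipy.optimize.minimize as a stand-in for ipopt and a simple placeholder cost function; both are assumptions for illustration only.

import numpy
import torch
from scipy.optimize import minimize   # stand-in for ipopt in this sketch

# Placeholder cost function of ~10 parameters (illustration only).
def f(x):
    return torch.sum(x ** 2) + torch.prod(torch.cos(x))

x0 = numpy.full(10, 0.5)
result = minimize(
    lambda x: f(torch.tensor(x, dtype=torch.float64)).item(),
    x0,
    jac=lambda x: jac(f, x),     # jac and hess are the wrappers defined above
    hess=lambda x: hess(f, x),
    method="trust-constr",       # a SciPy method that accepts an explicit Hessian
)
print(result.x)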

Since the function whose derivatives are computed is always the same, I was wondering whether there is a way (and whether it makes sense) to save the computational graph, to avoid recomputing the same quantities many times and thus speed up the calculation.

Thanks,

Daniel

Hi,

The short answer is:
PyTorch is built in such a way that the overhead of graph creation is low enough (compared to all the computation you do in the forward pass) that you can afford to rebuild the graph at every iteration.

The long answer is:
This is true for neural networks, where each op is quite “large”. Unfortunately, if you only have very small ops, the overhead can become noticeable.

The usual answer is: “use the jit to get the last drop of perf for your production evaluation”. But since you actually want to use backward here, you might run into trouble.
You can try to jit your function f; I would be curious to know whether it leads to any improvement.
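
If you want to try it, this is roughly what scripting f could look like (a minimal sketch; the f below is just a placeholder for your actual cost function):

import torch

# Placeholder cost function; replace with the real one.
def f(x):
    return torch.sum(x ** 2) + torch.prod(torch.cos(x))

scripted_f = torch.jit.script(f)   # compile the forward computation with TorchScript

x = torch.rand(10, dtype=torch.float64)
jac = torch.autograd.functional.jacobian(scripted_f, x)
hess = torch.autograd.functional.hessian(scripted_f, x)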

Thanks for your answer.

I have tried to use the jit to speed up the computation, but it doesn’t really work: I was able to jit the original function, but this does not lead to a relevant speedup, and I can’t jit the gradient and the Hessian.

In the same gist I also compare to JAX for the gradient and Hessian. With JAX it is possible to precompile the gradient function using jit. It seems that PyTorch is more or less at the same level as JAX for the gradient computation, but not for the Hessian computation: Hessian is about 20x slower in PyTorch than in JAX.
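
For reference, the JAX side of such a comparison looks roughly like this (a sketch with a placeholder cost function, not the exact code from the gist):

import jax
import jax.numpy as jnp

# Placeholder scalar cost function (illustration only).
def f(x):
    return jnp.sum(x ** 2) + jnp.prod(jnp.cos(x))

# jit compiles the transformed functions once; later calls reuse the compiled code.
grad_f = jax.jit(jax.grad(f))
hess_f = jax.jit(jax.hessian(f))

x = jnp.linspace(0.1, 1.0, 10)
g = grad_f(x)   # compiled on the first call, fast on subsequent calls
h = hess_f(x)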

It seems that PyTorch is more or less at the same level as JAX for the gradient computation

That’s good news; it means that your ops are big enough that the creation of the graph is negligible.

Hessian is about 20x slower in PyTorch than in JAX.

That might be due to other reasons, mainly the way the Hessian is computed: JAX can use forward-mode AD to speed this up, while PyTorch does not have forward-mode AD (yet :wink: ).
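
To make that concrete, JAX’s hessian is essentially a forward-over-reverse composition, which is what gives it the edge here (a sketch, reusing a one-line placeholder f):

import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 2) + jnp.prod(jnp.cos(x))   # placeholder cost

# Forward-over-reverse: forward-mode Jacobian of the reverse-mode gradient.
# This is essentially what jax.hessian does under the hood.
hess_fwd_over_rev = jax.jit(jax.jacfwd(jax.jacrev(f)))

# Reverse-over-reverse, closer to what torch.autograd.functional.hessian
# does by default, is typically slower for this kind of problem.
hess_rev_over_rev = jax.jit(jax.jacrev(jax.jacrev(f)))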

Consider the example of computing determinants using expansion by minors. (I know this is not an efficient way to compute determinants, but it demonstrates a case with lots of operations on scalar values.)

I am trying to re-use the graph for this case. My method works for 2 by 2 matrices but fails for 3 by 3 matrices (see the Python programs below), and I do not know why.

det_22.py

# Test reusing graph for derivatives of determinant of 2 by 2 matrix.
# The output generated by this program is below:
#
# First gradient passed check.
# Second gradient passed check.
#
# imports
import torch
import numpy
#
# check_grad
def check_grad(ax) :
   # ok, eps99
   ok    = True
   eps99 = 99.0 * numpy.finfo(float).eps
   #
   # ok
   check = ax.data[1,1]
   if abs( ax.grad[0,0] - check ) > eps99 :
      ok = False
      print( f'ax.grad[0,0] = {ax.grad[0,0]}, check = {check}' )
   #
   # ok
   check = -ax.data[1,0]
   if abs( ax.grad[0,1] - check ) > eps99 :
      ok = False
      print( f'ax.grad[0,1] = {ax.grad[0,1]}, check = {check}' )
   #
   # ok
   check = -ax.data[0,1]
   if abs( ax.grad[1,0] - check ) > eps99 :
      ok = False
      print( f'ax.grad[1,0] = {ax.grad[1,0]}, check = {check}' )
   #
   # ok
   check = ax.data[0,0]
   if abs( ax.grad[1,1] - check ) > eps99 :
      ok = False
      print( f'ax.grad[1,1] = {ax.grad[1,1]}, check = {check}' )
   #
   return ok
#
# main
def main() :
   #
   # ok
   ok = True
   #
   # n
   n = 2
   #
   # ax
   x  = numpy.random.uniform(0.0, 1.0, (n , n) )
   ax = torch.tensor(x, requires_grad = True)
   #
   # az
   az  = ax[0,0] * ax[1,1] - ax[0,1] * ax[1,0]
   #
   # ax.grad
   az.backward(retain_graph = True)
   #
   # check_grad
   if check_grad(ax) :
      print( 'First gradient passed check.' )
   else :
      print( 'First gradient failed check.' )
   #
   #
   # ax.data
   x  = numpy.random.uniform(0.0, 1.0, (n, n) )
   for i in range(n) :
      for j in range(n) :
         ax.data[i,j] = x[i,j]
   #
   # ax.grad
   ax.grad.zero_()
   az.backward(retain_graph = True)
   #
   # check_grad
   if check_grad(ax) :
      print( 'Second gradient passed check.' )
   else :
      print( 'Second gradient failed check.' )
#
main()

det_33.py:

# Test reusing graph for derivatives of determinant of 3 by 3 matrix.
# The output generated by this program is below. The actual numbers
# in the output will vary because a different random matrix is chosen
# for each evaluation.
#
# First gradient passed check.
# ax.grad[0,0] = 0.07585514040844837, check = 0.4295608074373773
# ax.grad[0,1] = -0.6133183512861293, check = -0.11782369019260797
# ax.grad[0,2] = 0.5337097801031835, check = 0.040633019648616306
# Second gradient failed check.
#
#
# imports
import torch
import numpy
#
# check_grad
def check_grad(ax) :
   # ok, eps99
   ok    = True
   eps99 = 99.0 * numpy.finfo(float).eps
   #
   # ok
   check = ( ax[1,1] * ax[2,2] - ax[1,2] * ax[2,1] )
   if abs( ax.grad[0,0] - check ) > eps99 :
      ok = False
      print( f'ax.grad[0,0] = {ax.grad[0,0]}, check = {check}' )
   #
   # ok
   check = - ( ax[1,0] * ax[2,2] - ax[1,2] * ax[2,0] )
   if abs( ax.grad[0,1] - check ) > eps99 :
      ok = False
      print( f'ax.grad[0,1] = {ax.grad[0,1]}, check = {check}' )
   #
   # ok
   check = ( ax[1,0] * ax[2,1] - ax[1,1] * ax[2,0] )
   if abs( ax.grad[0,2] - check ) > eps99 :
      ok = False
      print( f'ax.grad[0,2] = {ax.grad[0,2]}, check = {check}' )
   #
   return ok
#
# main
def main() :
   #
   # ok
   ok = True
   #
   # n
   n = 3
   #
   # ax
   x  = numpy.random.uniform(0.0, 1.0, (n, n))
   ax = torch.tensor(x, requires_grad = True)
   #
   #  ax[0,0]  ax[0,1]  ax[0,2]
   #  ax[1,0]  ax[1,1]  ax[1,2]
   #  ax[2,0]  ax[2,1]  ax[2,2]
   #
   # az
   az  = ax[0,0] * ( ax[1,1] * ax[2,2] - ax[1,2] * ax[2,1] )
   az -= ax[0,1] * ( ax[1,0] * ax[2,2] - ax[1,2] * ax[2,0] )
   az += ax[0,2] * ( ax[1,0] * ax[2,1] - ax[1,1] * ax[2,0] )
   #
   #
   # ax.grad
   az.backward(retain_graph = True)
   #
   # check_grad
   if check_grad(ax) :
      print( 'First gradient passed check.' )
   else :
      print( 'First gradient failed check.' )
   #
   #
   # ax.data
   x  = numpy.random.uniform(0.0, 1.0, (n, n) )
   for i in range(n) :
      for j in range(n) :
         ax.data[i,j] = x[i,j]
   #
   # ax.grad
   ax.grad.zero_()
   az.backward(retain_graph = True)
   #
   # check_grad
   if check_grad(ax) :
      print( 'Second gradient passed check.' )
   else :
      print( 'Second gradient failed check.' )
#
main()

Hi! :slightly_smiling_face:

I’m using torch.autograd.functional.jacobian to include the gradient of the outputs with respect to the inputs as part of the loss for training a NN.

However, once the gradient is computed and added to the loss, its graph is freed. As a result, loss.backward() cannot flow through the gradient-based loss term.

Is there a solution for that, e.g. some way of setting retain_graph=True for torch.autograd.functional.jacobian?

Thanks in advance for helping!

Hey,
You’re looking for the “create_graph=True” flag, I think.
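
For anyone else hitting this, a minimal sketch of what that looks like (the network, data, and weighting below are placeholders, not from the original post):

import torch
import torch.nn as nn

# Placeholder network and data, for illustration only.
model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))
x = torch.rand(3)
target = torch.zeros(1)

def f(inp):
    return model(inp)

# create_graph=True keeps the graph of the Jacobian computation,
# so the Jacobian-based term stays differentiable w.r.t. the parameters.
jac = torch.autograd.functional.jacobian(f, x, create_graph=True)

loss = (f(x) - target).pow(2).sum() + 0.1 * jac.pow(2).sum()
loss.backward()   # gradients now also flow through the Jacobian penalty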