Efficient computation of Hessian with respect to network weights using autograd.grad and symmetry of Hessian matrix

As the title says, I am trying to compute the Hessian with respect to the weights (not the inputs) of the network. Right now, my code is as follows:

import torch
from torch import autograd

def hessian(loss, params):
    # First-order gradients; create_graph=True makes them differentiable again.
    grads = autograd.grad(loss, params, retain_graph=True, create_graph=True)
    flattened_grads = torch.cat([g.flatten() for g in grads])

    n = flattened_grads.shape[0]
    hessian = torch.zeros(n, n)

    # One extra backward pass per scalar entry of the gradient gives one row.
    for idx in range(n):
        row = autograd.grad(flattened_grads[idx], params, retain_graph=True)
        hessian[idx, :] = torch.cat([r.flatten() for r in row])

    return hessian
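
For concreteness, this is how I call it (toy model and shapes made up for illustration):

import torch
import torch.nn as nn

# Hypothetical two-layer network, just to exercise hessian() above.
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
x, y = torch.randn(8, 3), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)

H = hessian(loss, list(model.parameters()))
print(H.shape)                             # torch.Size([21, 21]) for this toy model
print(torch.allclose(H, H.T, atol=1e-5))   # the Hessian should be symmetric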

That works, but it is inefficient: the Hessian matrix is symmetric, so only half of it actually needs to be computed.

To make it more efficient, I thought of the following:

flat_params = torch.cat([par.flatten() for par in params])
for idx in range(flattened_grads.shape[0]):
    second_der = autograd.grad(flattened_grads[idx], flat_params[idx:],
                               retain_graph=True, allow_unused=True)
    hessian[idx, idx:] = torch.cat([d.flatten() for d in second_der])

Namely, I calculate only the upper triangular part, and the lower half can be copied over afterwards. Unfortunately, autograd.grad(flattened_grads[idx], flat_params[idx:], retain_graph=True, allow_unused=True) returns None. How can I fix this? Is there a way to do what I want?
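
I suspect the problem is that flat_params is created by torch.cat, so it is not part of the graph that produced loss (the loss only ever saw the original parameter tensors), which would explain why every gradient comes back as None. If so, the best fallback I can think of still computes full rows and just mirrors the upper triangle; a rough, untested sketch:

n = flattened_grads.shape[0]
hess = torch.zeros(n, n)
for idx in range(n):
    # Each autograd.grad call still runs a full backward pass, so this
    # saves the copying and writes, not the differentiation itself.
    row = autograd.grad(flattened_grads[idx], params, retain_graph=True)
    row = torch.cat([r.flatten() for r in row])
    hess[idx, idx:] = row[idx:]
hess = hess + hess.T - torch.diag(hess.diag())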

If you’re trying to compute a Hessian efficiently, I’d recommend using the functorch package. You can find their repo here, and their documentation on higher-order gradients here.

Is functorch much more optimized than autograd.grad? I thought functorch was mainly meant to enable per-batch analysis, i.e., getting the gradient per sample instead of the sum over all samples, while being about as fast as autograd.grad.
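
The per-sample pattern I mean is roughly the following, with a made-up loss and made-up shapes:

import torch
from functorch import grad, vmap

# Toy example of per-sample gradients (hypothetical loss and shapes).
def compute_loss(weights, x, y):
    return ((x @ weights - y) ** 2).mean()

weights = torch.randn(3)
x, y = torch.randn(8, 3), torch.randn(8)

# grad() differentiates w.r.t. the first argument; vmap() maps that
# gradient over the batch dimension, yielding one gradient per sample.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, x, y)
print(per_sample_grads.shape)  # torch.Size([8, 3])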

It’s not just for per-sample gradients (although functorch’s vmap will handle that very easily). I’d recommend giving the documentation a brief read; one of its main features is efficient Hessian computation. The documentation (with example code) can be found here: Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms — functorch 0.2.0 documentation
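
As a minimal taste of what that tutorial shows (a toy function, not your network):

import torch
from functorch import hessian

# hessian() composes jacfwd/jacrev under the hood instead of looping
# over gradient entries one backward pass at a time.
def f(x):
    return x.sin().sum()

x = torch.randn(5)
H = hessian(f)(x)
print(H.shape)  # torch.Size([5, 5])

For Hessians with respect to weights rather than inputs, the tutorials pair this with functorch’s make_functional, which turns a module into a pure function whose parameters are explicit arguments.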