Computing Hessian for loss function

I’m looking at an implementation for calculating the Hessian matrix of the loss function.

```python
import torch

loss = self.loss_function()
# p is the weight matrix for a particular layer
grad_params = torch.autograd.grad(loss, p, create_graph=True)
hess_params = torch.zeros_like(grad_params[0])

# Differentiate each gradient entry again, keeping only element [i, j] of the result
for i in range(grad_params[0].size(0)):
    for j in range(grad_params[0].size(1)):
        hess_params[i, j] = torch.autograd.grad(grad_params[0][i, j], p, retain_graph=True)[0][i, j]
```

I have three questions:

  1. Why do we compute the Hessian in a loop? Can't we use something along the lines of
hess_params = torch.autograd.grad(grad_params, p, retain_graph=True)
  2. The current code takes hours to run for larger weight matrices. What can I do to speed it up?

  3. I have seen that a hessian function is implemented in the autograd package. How can we use it in this case?

Pointers to reading resources and similar questions would also be much appreciated.
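For question 3, here is a minimal sketch of `torch.autograd.functional.hessian` applied to a single layer's weight. The function needs the loss written as a function of the tensor being differentiated, so the toy data `x`, `y`, the weight `weight` and the helper `loss_wrt_weight` are hypothetical stand-ins for the real model and `self.loss_function()`:

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)

# Hypothetical toy setup: one linear layer with a 3x2 weight matrix
x = torch.randn(5, 2)
y = torch.randn(5, 3)
weight = torch.randn(3, 2)

def loss_wrt_weight(w):
    # The loss expressed as a pure function of the weight we want the Hessian for
    pred = x @ w.t()
    return ((pred - y) ** 2).mean()

H = hessian(loss_wrt_weight, weight)                   # shape weight.shape + weight.shape = (3, 2, 3, 2)
H_matrix = H.reshape(weight.numel(), weight.numel())   # (6, 6) matrix view of the same Hessian
print(H.shape, H_matrix.shape)
```

The result is the full Hessian with shape `p.shape + p.shape`, so it runs into the same size limits discussed further down. The function also takes a `vectorize=True` flag, which trades memory for speed in a similar way to the batching trick mentioned below.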

There's a solution from Adam that uses a single for-loop instead of a nested one:

A proposal to provide Hessians natively in PyTorch, by @albanD
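The linked code is not reproduced here, but the single-loop idea can be sketched roughly as follows (`full_hessian` is a hypothetical helper, not Adam's code). Each backward pass through one entry of the flattened gradient returns a whole row of the Hessian. The nested loop in the question performs the same number of backward passes but throws away all but one entry of each row, so keeping the rows gives the full Hessian at no extra backward cost:

```python
import torch

def full_hessian(loss, p):
    """Full Hessian of a scalar `loss` w.r.t. tensor `p`, one backward pass per row."""
    # First-order gradient, kept differentiable so it can be differentiated again
    (grad,) = torch.autograd.grad(loss, p, create_graph=True)
    grad = grad.reshape(-1)
    rows = []
    for k in range(grad.numel()):
        # d(grad_k)/dp is the entire k-th row of the Hessian
        (row,) = torch.autograd.grad(grad[k], p, retain_graph=True)
        rows.append(row.reshape(-1))
    return torch.stack(rows)  # shape (n, n) with n = p.numel()

p = torch.randn(2, 3, requires_grad=True)   # stand-in for a layer's weights
loss = (p ** 3).sum() + p.sum() ** 2        # toy loss with cross terms
print(full_hessian(loss, p).shape)          # torch.Size([6, 6])
```

This assumes every gradient entry still depends on `p` (true for any loss that is nonlinear in `p`); for a loss that is linear in `p` the second call would raise an error because the gradient is constant.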

You can replace k backward calls with a single backward call by putting the k vectors in the batch dimension; however, this also increases your memory usage by a factor of k. See this thread: Efficient computation with multiple grad_output's in autograd.grad
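As a sketch of that batching idea applied to the Hessian, assuming a PyTorch version whose `torch.autograd.grad` accepts `is_grads_batched` (it is not present in older releases): the n rows of an identity matrix are stacked as a batch of `grad_outputs`, so one vectorized call computes all n Hessian rows. `hessian_batched` is a hypothetical helper name:

```python
import torch

def hessian_batched(loss, p):
    """Full Hessian of scalar `loss` w.r.t. `p` with one batched backward call.

    The identity matrix holds n one-hot vectors; memory grows with n.
    """
    (grad,) = torch.autograd.grad(loss, p, create_graph=True)
    n = grad.numel()
    # Row k of the identity selects gradient entry k; the leading dim is the batch dim
    eye = torch.eye(n, dtype=grad.dtype, device=grad.device).reshape(n, *grad.shape)
    (H,) = torch.autograd.grad(grad, p, grad_outputs=eye, is_grads_batched=True)
    return H.reshape(n, n)

p = torch.randn(2, 3, requires_grad=True)
loss = (p ** 3).sum() + p.sum() ** 2   # toy loss with cross terms
print(hessian_batched(loss, p).shape)  # torch.Size([6, 6])
```

The `eye` tensor alone has n² entries, which is exactly the factor-of-k memory cost mentioned above, so this only pays off when n is small enough for the full Hessian to fit in memory anyway.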

There's a fundamental problem: the Hessian is large. Take ResNet-50, which has about 25M parameters; its Hessian then has 625 trillion entries. This means that for large networks you have to work with factorized approximations or consider a subset of the entries, such as the diagonal, which can be obtained at a cost similar to that of the gradient.
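The size argument is plain arithmetic. A small sketch, assuming a dense Hessian stored in float32 (4 bytes per entry), also covers the roughly 460k-parameter network mentioned later in the thread:

```python
def full_hessian_size(num_params, bytes_per_entry=4):
    """Number of entries and bytes for a dense num_params x num_params Hessian."""
    entries = num_params ** 2
    return entries, entries * bytes_per_entry

entries, nbytes = full_hessian_size(25_000_000)   # ResNet-50, ~25M parameters
print(f"{entries:.3e} entries, {nbytes / 1e15:.1f} PB")  # 6.250e+14 entries, 2.5 PB

entries, nbytes = full_hessian_size(460_000)      # ~460k parameters
print(f"{entries:.3e} entries, {nbytes / 1e9:.1f} GB")   # 2.116e+11 entries, 846.4 GB
```

Even the smaller network would need hundreds of gigabytes for a dense float32 Hessian, which is why the diagonal or a factorized approximation is usually the practical target.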

For example, for ReLU networks you can get the diagonal exactly using the Gauss-Newton trick implemented here, and for more general networks you can use the Hutchinson estimator via Hessian-vector products, as in this colab.
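A minimal sketch of the Hutchinson diagonal estimator (not the code from the linked colab): for random vectors z with independent ±1 entries, the expectation of z ⊙ (Hz) equals diag(H), and each Hz is a Hessian-vector product that costs one extra backward pass through the gradient. `hutchinson_diag` and the toy quadratic loss are illustrative assumptions:

```python
import torch

def hutchinson_diag(loss_fn, p, num_samples=1000):
    """Estimate diag(H) of loss_fn at p as the average of z * (H z) over Rademacher z."""
    p = p.detach().requires_grad_(True)
    loss = loss_fn(p)
    (grad,) = torch.autograd.grad(loss, p, create_graph=True)
    estimate = torch.zeros_like(p)
    for _ in range(num_samples):
        # Rademacher probe: each entry is +1 or -1 with equal probability
        z = torch.randint(0, 2, p.shape, dtype=p.dtype) * 2 - 1
        # Hessian-vector product H z via a second backward pass through grad
        (hz,) = torch.autograd.grad(grad, p, grad_outputs=z, retain_graph=True)
        estimate += z * hz
    return estimate / num_samples

torch.manual_seed(0)
A = torch.tensor([[2.0, 0.5], [0.5, 1.0]])  # toy quadratic: H = A, so diag(H) = [2, 1]
est = hutchinson_diag(lambda w: 0.5 * w @ A @ w, torch.randn(2), num_samples=2000)
print(est)  # close to tensor([2., 1.]); exact values depend on the random probes
```

The estimator's variance comes from the off-diagonal entries, so more samples are needed when the Hessian is far from diagonal; the memory cost per sample stays at roughly that of one extra backward pass.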


Thanks for the detailed response. I’ll look into these. My network has about 460k parameters, so not as large as resnet-50, but definitely not small.

Interestingly enough, the comment at the beginning of this getHessian() function reads:

This function computes the diagonal entries of the Hessian matrix of the decoding NN parameters

But to me, this looks like a full Hessian matrix, doesn't it? I'll look into that possibility as well.
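A quick check on a toy problem suggests the docstring is accurate. Each inner `torch.autograd.grad` call returns a whole row of the Hessian (the gradient of one gradient entry with respect to all of `p`), but the trailing `[i, j]` keeps only one entry of that row, so `hess_params` ends up holding ∂²L/∂p_ij², i.e. the Hessian diagonal reshaped to `p`'s shape. The loss below is a made-up stand-in with cross terms, so its full Hessian is not diagonal, and `torch.autograd.functional.hessian` serves only as a reference:

```python
import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)
p = torch.randn(2, 3, requires_grad=True)  # stand-in for a layer's weight matrix

def loss_fn(w):
    # Toy loss with cross terms, so the full Hessian has nonzero off-diagonal entries
    return (w ** 3).sum() + w.sum() ** 2

# The nested loop from the question
loss = loss_fn(p)
grad_params = torch.autograd.grad(loss, p, create_graph=True)
hess_params = torch.zeros_like(grad_params[0])
for i in range(grad_params[0].size(0)):
    for j in range(grad_params[0].size(1)):
        hess_params[i, j] = torch.autograd.grad(grad_params[0][i, j], p, retain_graph=True)[0][i, j]

# Full Hessian for comparison, flattened to a (6, 6) matrix
H = hessian(loss_fn, p.detach()).reshape(p.numel(), p.numel())
print(torch.allclose(hess_params, H.diagonal().reshape(p.shape)))  # True
```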

@Yaroslav_Bulatov and others helped me formulate more efficient Jacobian calculators:

Essentially, you can eliminate for loops (and significantly speed up computation) at the cost of memory. This is probably too memory-hungry for your large Hessian, but I thought I'd add it in case it's useful.
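The linked calculators are not reproduced here; as a rough sketch of one common form of this loop-free trick, assuming `f` maps a batch of inputs of shape (batch, n) to outputs of shape (batch, m) and treats each row independently: replicate the input m times along the batch dimension and backpropagate the m × m identity, so row k of the input gradient is ∂f_k/∂x. `batched_jacobian` is a hypothetical name:

```python
import torch

def batched_jacobian(f, x, out_dim):
    """Jacobian of f: R^n -> R^m at x using a single backward pass.

    Assumes f processes each row of a batch independently. Memory grows with m.
    """
    xs = x.detach().unsqueeze(0).repeat(out_dim, 1).requires_grad_(True)  # (m, n) copies of x
    ys = f(xs)                                                             # (m, m) outputs
    # Backpropagating row k of the identity through copy k yields d f_k / d x
    ys.backward(torch.eye(out_dim, dtype=ys.dtype))
    return xs.grad                                                          # (m, n) Jacobian

W = torch.randn(3, 4)
x = torch.randn(4)
J = batched_jacobian(lambda v: torch.tanh(v @ W.t()), x, out_dim=3)
print(J.shape)  # torch.Size([3, 4])
```

Applying the same trick to the gradient function instead of `f` gives Hessian rows, with the replicated batch being where the extra memory goes.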
