Calculating the Hessian of a loss function w.r.t. Torch network parameters

Is there an efficient way to compute second-order gradients (at least a partial Hessian) of a loss function with respect to the parameters of a network using PyTorch autograd? The way `torch.autograd.functional.hessian(func, inputs, ...)` works doesn't play nice with Torch modules, since a standard loss function does not take the network parameters themselves as inputs; it operates on the network itself.
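To make the mismatch concrete, here is a minimal sketch (a toy linear model and MSE loss, both made up for illustration): `hessian()` differentiates twice with respect to the tensor you pass as `inputs`, but the model's parameters never appear in that argument, so there is no direct way to ask for the Hessian with respect to them.

```python
import torch
import torch.nn as nn

# Toy setup, purely illustrative.
model = nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
criterion = nn.MSELoss()

# torch.autograd.functional.hessian differentiates w.r.t. the explicit
# tensor argument of `func` -- here the *input* x, not model.parameters().
H_wrt_input = torch.autograd.functional.hessian(
    lambda inp: criterion(model(inp), y), x
)
print(H_wrt_input.shape)  # torch.Size([8, 4, 8, 4])
```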

Two leads I have on this are the following, but neither really solves the problem:

  1. A similar problem is discussed here in the forums: Using `autograd.functional.jacobian`/`hessian` with respect to `nn.Module` parameters. It’s a bit outdated now perhaps, so I don’t know if there are more viable solutions at the moment.

  2. I found a rather cumbersome possible workaround in machine learning - How to compute hessian matrix for all parameters in a network in pytorch? - Stack Overflow, which suggests writing a wrapper function that:
    a) takes in all network parameters, flattened into a single vector,
    b) unflattens them inside,
    c) and then essentially mimics a forward pass and computes the loss,
    all in order to make it play nice with `torch.autograd.functional.hessian()` (see the sketch after this list). It could work, but I feel there has to be a better way…
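For reference, a minimal sketch of that wrapper idea, assuming a small two-layer network and MSE loss (the network, data, and helper names here are illustrative, not taken from the linked answer):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative two-layer network and data.
model = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 1))
x = torch.randn(8, 4)
y = torch.randn(8, 1)

# a) flatten all parameters into a single vector
params = list(model.parameters())
shapes = [p.shape for p in params]
flat = torch.cat([p.detach().reshape(-1) for p in params])

def loss_from_flat(flat_params):
    # b) unflatten back into per-parameter tensors
    restored, offset = [], 0
    for shape in shapes:
        n = shape.numel()
        restored.append(flat_params[offset:offset + n].reshape(shape))
        offset += n
    w1, b1, w2, b2 = restored
    # c) mimic the forward pass by hand with the unflattened parameters
    h = torch.tanh(F.linear(x, w1, b1))
    out = F.linear(h, w2, b2)
    return F.mse_loss(out, y)

# Full Hessian of the loss w.r.t. all (flattened) parameters.
H = torch.autograd.functional.hessian(loss_from_flat, flat)
print(H.shape)  # (num_params, num_params)
```

The cumbersome part is step (c): the forward pass has to be re-written by hand for every architecture.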

Does anyone have a better way to work around this? Thank you!


Hi,

I am trying to compute the Hessian of a loss function w.r.t. the parameters of a network using PyTorch autograd. Could you please share your code here or give an example of how to do it?

Thanks,
Maryam