Calculating the Hessian of loss function wrt torch network parameters

Is there an efficient way to compute second-order gradients (at least a partial Hessian) of a loss function with respect to the parameters of a network using PyTorch autograd? The way torch.autograd.functional.hessian(func, inputs, ...) works does not play well with Torch modules, since a standard loss function does not take the network parameters as inputs; it operates on the network itself.
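For reference, here is a minimal sketch of how torch.autograd.functional.hessian expects to be called: `func` takes plain tensors and returns a scalar, and the Hessian is taken with respect to exactly those tensors. The function `f` is a made-up example.

```python
import torch
from torch.autograd.functional import hessian

# Hypothetical scalar function of a plain tensor: f(x) = sum(x_i^3)
def f(x):
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0, 3.0])
H = hessian(f, x)  # 3x3 matrix of second derivatives d^2 f / dx_i dx_j
# For sum(x_i^3) the Hessian is diagonal with entries 6 * x_i
print(H)
```

A loss such as `loss_fn(net(x), y)` depends on the network's parameters only through the module's internal state, so there is no tensor argument to pass as `inputs`.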

Two leads I have on this are the following, but neither really solves the problem:

  1. A similar problem is discussed here in the forums: Using `autograd.functional.jacobian`/`hessian` with respect to `nn.Module` parameters. It may be a bit outdated by now, so I don't know whether more viable solutions exist at the moment.

  2. I found a rather cumbersome possible workaround in machine learning - How to compute hessian matrix for all parameters in a network in pytorch? - Stack Overflow, which suggests writing a wrapper function that:
    a) takes all network parameters as a single flattened vector,
    b) unflattens them inside,
    c) and then essentially mimics a forward pass and computes the loss,
    all in order to make it play well with torch.autograd.functional.hessian(). It could work, but I feel there has to be a better way…
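A minimal sketch of that wrapper, assuming PyTorch 2.x: instead of hand-writing the forward pass in step c), it uses `torch.func.functional_call` to run the module with substituted parameter tensors. The model, the data, and the helper name `loss_from_flat` are hypothetical.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import hessian
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical tiny model and data, for illustration only
model = nn.Linear(3, 1)
x = torch.randn(8, 3)
y = torch.randn(8, 1)

# Record parameter names, shapes and sizes so a flat vector can be split back up
names = [n for n, _ in model.named_parameters()]
shapes = [p.shape for _, p in model.named_parameters()]
sizes = [p.numel() for _, p in model.named_parameters()]

def loss_from_flat(flat_params):
    # b) unflatten: split the 1-D vector and reshape each piece to its parameter's shape
    chunks = torch.split(flat_params, sizes)
    params = {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}
    # c) forward pass with these tensors in place of the module's own parameters
    out = functional_call(model, params, (x,))
    return ((out - y) ** 2).mean()  # mean squared error

# a) flatten all parameters into a single 1-D tensor (same order as named_parameters)
flat = torch.cat([p.detach().reshape(-1) for p in model.parameters()])

H = hessian(loss_from_flat, flat)  # shape (P, P), P = total number of parameters
print(H.shape)  # torch.Size([4, 4]) for nn.Linear(3, 1)
```

The full Hessian has P² entries, so this is only practical for small models or for a small subset of the parameters.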

Does anyone have a better way to work around this? Thank you!



I am trying to compute the Hessian of a loss function w.r.t. the parameters of a network using PyTorch autograd. Could you please share your code here or provide an example of how I can do it?


In the same boat; I haven't found a good way to do it yet. Please share if anyone has one. Thanks

Hi @y12uc231, @Mari and @semihcanturk,

If you want to compute the Hessian of a network, you might want to have a look at the functorch library (which now comes packaged with the latest install of PyTorch). A brief example can be found on their GitHub page.

In general, you can do something like this:

```python
from functorch import make_functional, jacfwd, jacrev, vmap

net = Model()  # model instance
# functorch requires a functional form of the model (fnet) that also takes the parameters (params) as an input
fnet, params = make_functional(net)

# Per-sample Hessian of the network output w.r.t. params (jacfwd over jacrev),
# vmapped over the batch dimension of x; params are shared, hence in_dims=(None, 0)
per_sample_hessian = vmap(jacfwd(jacrev(fnet, argnums=0), argnums=0), in_dims=(None, 0))(params, x)
```

where x is the input batch to the model. This calculates the Hessian of the network's output with respect to its parameters, using the forward-over-reverse method (jacfwd over jacrev) for efficiency. The example also uses functorch's vmap to vectorize the calculation over the batch dimension, which is critical for computing higher-order derivatives efficiently!
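Since the original question asks for the Hessian of a loss rather than of the network output, here is a minimal sketch of the same forward-over-reverse idea using `torch.func` (the successor to functorch in PyTorch 2.x), where `functional_call` takes the place of `make_functional`. The model, the data and the name `compute_loss` are hypothetical.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacfwd, jacrev

torch.manual_seed(0)

# Hypothetical model and data, for illustration only
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
x = torch.randn(8, 3)
y = torch.randn(8, 1)

# Parameters as a dict of plain tensors; this dict is the differentiated input
params = {k: v.detach() for k, v in model.named_parameters()}

def compute_loss(params, x, y):
    # Forward pass using the given tensors instead of the module's stored parameters
    out = functional_call(model, params, (x,))
    return ((out - y) ** 2).mean()  # mean squared error

# Forward-over-reverse Hessian of the loss w.r.t. params (argnums=0)
hess = jacfwd(jacrev(compute_loss, argnums=0), argnums=0)(params, x, y)

# hess is a nested dict: hess[name_i][name_j] has shape params[name_i].shape + params[name_j].shape
print(hess["0.weight"]["2.bias"].shape)  # torch.Size([4, 3, 1])
```

Keeping the parameters as a dict avoids the manual flatten/unflatten step; each block of the Hessian is indexed by a pair of parameter names.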

It's quite a lot to take in, so I recommend reading their notebook on Hessian calculations (along with the rest of the functorch documentation): Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms — functorch 1.13 documentation
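When only products of the Hessian with a vector are needed (often enough for a "partial Hessian", and it avoids materializing all P² entries), a Hessian-vector product can be computed by forward-mode differentiation of the gradient. A minimal sketch with `torch.func`, assuming PyTorch 2.x; the helper `hvp` and the toy loss are hypothetical.

```python
import torch
from torch.func import grad, jvp

def hvp(f, primals, tangents):
    # Forward-over-reverse: directional derivative of grad(f) along tangents, i.e. H @ v
    return jvp(grad(f), primals, tangents)[1]

# Hypothetical scalar loss of a parameter vector w: sum(w_i^3), whose Hessian is diag(6 * w)
def loss(w):
    return (w ** 3).sum()

w = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([1.0, 0.0, 1.0])
# H @ v = diag(6 * w) @ v = [6, 0, 18]
print(hvp(loss, (w,), (v,)))
```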