Is there an efficient way to compute second-order gradients (at least a partial Hessian) of a loss function with respect to the parameters of a network using PyTorch autograd? torch.autograd.functional.hessian(func, inputs, ...) does not play nicely with Torch modules, since a standard loss function does not take the network parameters themselves as inputs; it operates through the network object instead.
Two leads I have on this are the following, but neither really solves the problem:
I found a rather cumbersome workaround in machine learning - How to compute hessian matrix for all parameters in a network in pytorch? - Stack Overflow, which suggests writing a wrapper function that:
a) takes all network parameters, flattened into a single vector,
b) unflattens them inside,
c) and then essentially mimics a forward pass and computes the loss,
all in order to make it compatible with torch.autograd.functional.hessian(). It could work, but I feel there has to be a better way…
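For reference, the flatten/unflatten wrapper described above can be sketched roughly as follows. This is a minimal, illustrative example, not the exact code from the Stack Overflow answer: the toy model, data, and helper names are assumptions, and it uses torch.func.functional_call (available in recent PyTorch releases) to run the forward pass with substituted parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(3, 1)          # toy model (illustrative)
x = torch.randn(8, 3)          # batch of inputs
y = torch.randn(8, 1)          # targets
loss_fn = nn.MSELoss()

# Record names/shapes so we can rebuild the parameter dict from a flat vector.
names = [n for n, _ in net.named_parameters()]
shapes = [p.shape for p in net.parameters()]
numels = [p.numel() for p in net.parameters()]
flat_params = torch.cat([p.detach().reshape(-1) for p in net.parameters()])

def loss_wrt_flat(flat):
    # a) the flat parameter vector is the explicit input...
    chunks = flat.split(numels)
    # b) ...unflatten it back into per-parameter tensors...
    params = {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}
    # c) ...and run the forward pass with those parameters to get the loss.
    out = torch.func.functional_call(net, params, (x,))
    return loss_fn(out, y)

# Now the loss is a plain function of a tensor, so hessian() applies directly.
H = torch.autograd.functional.hessian(loss_wrt_flat, flat_params)
# H has shape (num_params, num_params) — here (4, 4) for Linear(3, 1).
```

This works, but the flattening bookkeeping is exactly the cumbersome part.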
Does anyone have a better way to work around this? Thank you!
I am trying to compute the Hessian of a loss function w.r.t. the parameters of a network using PyTorch autograd. Could you please share your code here, or provide an example of how I can do it?
If you want to compute the Hessian of a network, you might want to have a look at the functorch library (which now ships with recent PyTorch installs). A brief example from their GitHub page is here: https://github.com/pytorch/functorch/issues/989.
In general, you can do something like this,
from functorch import make_functional, jacfwd, jacrev, vmap

net = Model()  # model instance
# functorch requires a functional form of the model (fnet) that takes the
# parameters (params) as an explicit input alongside the data.
fnet, params = make_functional(net)
# Forward-over-reverse Hessian w.r.t. params, vectorized over the batch dim of x.
per_sample_hessian = vmap(jacfwd(jacrev(fnet, argnums=0), argnums=0), in_dims=(None, 0))(params, x)
where x is the input to the model, and we're calculating the Hessian with respect to the network's parameters via the forward-over-reverse composition jacfwd(jacrev(...)), which is typically the most efficient way to compute a Hessian. The example also uses functorch's vmap to vectorize the calculation over the batch dimension, which is critical for computing higher-order derivatives efficiently!
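In recent PyTorch releases this same API lives in torch.func, with functional_call taking the place of make_functional. Here is a self-contained sketch of the recipe above under that API; the toy model, input, and per-sample function are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacfwd, jacrev, vmap

torch.manual_seed(0)
net = nn.Linear(3, 1)      # toy model (illustrative)
x = torch.randn(8, 3)      # batch of 8 inputs

# Parameters as an explicit dict, so they can be differentiated against.
params = dict(net.named_parameters())

def f(p, xi):
    # Scalar output per sample, so the Hessian is well defined.
    return functional_call(net, p, (xi,)).squeeze()

# jacrev computes the gradient w.r.t. params; jacfwd differentiates that
# again in forward mode; vmap vectorizes over the batch dimension of x.
per_sample_hessian = vmap(jacfwd(jacrev(f, argnums=0), argnums=0),
                          in_dims=(None, 0))(params, x)

# The result is a nested dict: per_sample_hessian[name_i][name_j] holds the
# (name_i, name_j) Hessian block, each with a leading batch dimension of 8.
```

Since each block keeps the shapes of the two parameters it differentiates against (plus the batch dimension), e.g. the weight-weight block here has shape (8, 1, 3, 1, 3), you will usually reshape these blocks before assembling a full Hessian matrix.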