Is there an efficient way to compute second-order gradients (at least a partial Hessian) of a loss function with respect to the parameters of a network using PyTorch autograd? The way
torch.autograd.functional.hessian(func, inputs, ...) works doesn't play nicely with Torch modules, since a standard loss function does not take the network parameters themselves as inputs; it operates on the network object instead.
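For concreteness, here is a minimal sketch of the mismatch I mean (the toy model and data are made up for illustration):

```python
import torch
import torch.nn as nn

# Toy model and data, purely illustrative.
model = nn.Linear(3, 1)
x, y = torch.randn(5, 3), torch.randn(5, 1)

# A typical loss closure: the parameters live inside `model` as attributes,
# they are not explicit tensor arguments of the function.
def loss():
    return nn.functional.mse_loss(model(x), y)

# ...but hessian() differentiates `func` w.r.t. the tensors passed as `inputs`,
# so there is no tensor argument to hand it here:
# torch.autograd.functional.hessian(loss, ???)
```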
Two leads I have on this are the following, but neither really solves the problem:
A similar problem is discussed here in the forums: Using `autograd.functional.jacobian`/`hessian` with respect to `nn.Module` parameters. The thread is a bit outdated by now, though, so I don't know whether more viable solutions have appeared since.
I found a rather cumbersome possible workaround in machine learning - How to compute hessian matrix for all parameters in a network in pytorch? - Stack Overflow, which suggests writing a wrapper function that:
a) takes all network parameters flattened into a single vector,
b) unflattens them inside,
c) and then essentially mimics a forward pass and computes the loss,
all in order to make it compatible with
torch.autograd.functional.hessian(). It could work, but I feel there has to be a better way…
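For reference, here is a minimal sketch of what I understand that workaround to look like. The toy model and data are made up, and I'm assuming `torch.func.functional_call` (PyTorch ≥ 2.0) for the stateless forward pass:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch >= 2.0

# Toy model and data, purely illustrative.
torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(5, 3), torch.randn(5, 1)

# Record parameter names, shapes, and sizes so a flat vector can be unflattened.
names = [n for n, _ in model.named_parameters()]
shapes = [p.shape for p in model.parameters()]
sizes = [p.numel() for p in model.parameters()]

def loss_from_flat(flat_params):
    # b) unflatten the vector back into per-parameter tensors
    chunks = flat_params.split(sizes)
    params = {n: c.view(s) for n, c, s in zip(names, chunks, shapes)}
    # c) a stateless forward pass with those tensors, then the loss
    out = functional_call(model, params, (x,))
    return nn.functional.mse_loss(out, y)

# a) all network parameters flattened into a single vector
flat = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
H = torch.autograd.functional.hessian(loss_from_flat, flat)
print(H.shape)  # -> torch.Size([4, 4]): 3 weights + 1 bias
```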
Does anyone have a better way to work around this? Thank you!