Is there an efficient way to compute second-order gradients (at least a partial Hessian) of a loss function with respect to the parameters of a network using PyTorch autograd? The way
torch.autograd.functional.hessian(func, inputs, ...) works doesn't play nicely with Torch modules, since a standard loss function does not take the network parameters themselves as inputs; it operates on the network object instead.
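For concreteness, here is a minimal sketch of the mismatch I mean (the toy model and data are made up for illustration):

```python
import torch
import torch.nn as nn

# Toy model and data, purely illustrative.
model = nn.Linear(3, 1)
x, y = torch.randn(5, 3), torch.randn(5, 1)

# A typical loss closure: the parameters live inside `model` as attributes,
# they are not explicit tensor arguments of the function.
def loss():
    return nn.functional.mse_loss(model(x), y)

# ...but hessian() differentiates `func` w.r.t. the tensors passed as `inputs`,
# so there is no tensor argument to hand it here:
# torch.autograd.functional.hessian(loss, ???)
```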
Two leads I have on this are the following, but neither really solves the problem:
A similar problem is discussed here in the forums: Using `autograd.functional.jacobian`/`hessian` with respect to `nn.Module` parameters. The thread is a bit outdated by now, though, so I don't know whether more viable solutions have appeared since.
I found a rather cumbersome possible workaround in machine learning - How to compute hessian matrix for all parameters in a network in pytorch? - Stack Overflow, which suggests writing a wrapper function that:
a) takes all network parameters flattened into a single vector,
b) unflattens them inside,
c) and then essentially mimics a forward pass and computes the loss,
all in order to make it compatible with
torch.autograd.functional.hessian(). It could work, but I feel there has to be a better way…
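For reference, here is a minimal sketch of what I understand that workaround to look like. The toy model and data are made up, and I'm assuming `torch.func.functional_call` (PyTorch ≥ 2.0) for the stateless forward pass:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch >= 2.0

# Toy model and data, purely illustrative.
torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(5, 3), torch.randn(5, 1)

# Record parameter names, shapes, and sizes so a flat vector can be unflattened.
names = [n for n, _ in model.named_parameters()]
shapes = [p.shape for p in model.parameters()]
sizes = [p.numel() for p in model.parameters()]

def loss_from_flat(flat_params):
    # b) unflatten the vector back into per-parameter tensors
    chunks = flat_params.split(sizes)
    params = {n: c.view(s) for n, c, s in zip(names, chunks, shapes)}
    # c) a stateless forward pass with those tensors, then the loss
    out = functional_call(model, params, (x,))
    return nn.functional.mse_loss(out, y)

# a) all network parameters flattened into a single vector
flat = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
H = torch.autograd.functional.hessian(loss_from_flat, flat)
print(H.shape)  # -> torch.Size([4, 4]): 3 weights + 1 bias
```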
Does anyone have a better way to work around this? Thank you!