import torch

def fn(x): return x.pow(2).mean()

input = torch.randn(3, requires_grad=True)  # example input (not defined in the original snippet)
loss = fn(input)
grad, = torch.autograd.grad(loss, input)    # autograd.grad returns a tuple, so unpack it
hessian = torch.func.hessian(fn)(input)
Doesn’t hessian() then need to re-evaluate the function value and the gradients even though I have already computed them? My function might be quite expensive. What would be the best way to compute the hessian?
I believe that you are correct about this. I am not aware of any pre-packaged PyTorch hessian functionality that also gives you access to the loss and grad it would have computed under the hood.
So I do believe that func.hessian() does require duplicative computation in your use
case.
On the other hand, if your grad consists of n components, computing the hessian
(e.g., with func.hessian()) requires n autograd passes (in addition to the first autograd
pass used to compute grad). So the duplicative computation may well be relatively
insignificant.
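To make that concrete, here is a minimal sketch of building the hessian by hand so that the already-computed loss and grad are reused and only the n second-derivative passes remain. It assumes a 1-D input and the toy fn from your snippet, and it requires computing grad with create_graph=True (which your original snippet does not do):

import torch

def fn(x): return x.pow(2).mean()

input = torch.randn(3, requires_grad=True)   # example input
loss = fn(input)                             # this loss value can be reused as-is
grad, = torch.autograd.grad(loss, input, create_graph=True)  # create_graph lets grad be differentiated again

# one reverse-mode pass per component of grad builds one row of the hessian
rows = [torch.autograd.grad(g, input, retain_graph=True)[0] for g in grad]
hessian = torch.stack(rows)

For inputs that are not 1-D you would need to flatten (or use the more general torch.func machinery), so treat this as a sketch rather than a drop-in replacement for func.hessian().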
It is also conceivable that func.hessian() contains some minor internal efficiencies
that are sufficient to overcome the cost of the unnecessary computation.
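For reference, the torch.func documentation describes hessian() as a forward-over-reverse composition, so it is roughly equivalent to the following (using the fn and input from the snippets above):

# forward-mode jacobian of the reverse-mode gradient; this composition is
# often cheaper than taking n separate reverse-mode passes
hessian_fwd_over_rev = torch.func.jacfwd(torch.func.jacrev(fn))(input)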
Note that if your function is expensive, the first-derivative backward pass for grad and the n second-derivative backward passes for the hessian are also likely to be expensive. So, again, the cost of the n second-derivative backward passes may dominate the overall cost, with the cost of the redundant computation of loss and grad being relatively minor.
If the cost of hessian is important to your use case, it would probably make sense to
time both approaches.
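A rough timing harness along those lines (a sketch, assuming the fn and input from the snippets above; manual_hessian is just a hypothetical helper wrapping the loop shown earlier, and real measurements on GPU would also need synchronization):

import time

def manual_hessian(x):
    # recomputes loss and grad here only so the comparison is self-contained
    loss = fn(x)
    grad, = torch.autograd.grad(loss, x, create_graph=True)
    return torch.stack([torch.autograd.grad(g, x, retain_graph=True)[0] for g in grad])

for name, f in [("func.hessian", lambda x: torch.func.hessian(fn)(x)),
                ("manual loop ", manual_hessian)]:
    f(input)                                 # warm-up call
    start = time.perf_counter()
    for _ in range(100):
        f(input)
    print(name, (time.perf_counter() - start) / 100)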