Generate Hessian diagonal matrix for every sample in parallel

Hi, I noticed that there are some repositories that let us get the gradient of each sample in a batch.
Is there any way I could get the diagonal of the Hessian matrix for each sample in every batch?

If there’s no particular structure to your function (such as being element-wise), you might as well compute the whole Hessian. And if you want that for every sample, you’d do vmap(hessian(fn))(...).
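To illustrate why element-wise structure helps: if f(x) = sum_i g(x_i), every term depends on a single coordinate, so the Hessian is diagonal and its diagonal equals the Hessian-vector product H @ 1. That needs one forward-over-reverse pass instead of building the full D x D matrix. The sketch below assumes PyTorch 2.0 or later, where functorch's transforms are exposed as torch.func; `f` and `hessian_diag_elementwise` are hypothetical names for illustration. This shortcut is only valid when the Hessian really is diagonal; for a general function it returns row sums of H, not its diagonal.

```python
import torch
from torch.func import grad, jvp, vmap

torch.manual_seed(0)

def f(x):
    # Hypothetical element-wise function: sum_i sin(x_i) * x_i**2.
    # Each term depends on one coordinate only, so the Hessian is diagonal.
    return (torch.sin(x) * x ** 2).sum()

def hessian_diag_elementwise(f, x):
    # Forward-mode derivative of grad(f) along a ones vector gives H @ 1.
    # This equals diag(H) ONLY because H is diagonal for element-wise f.
    _, hvp = jvp(grad(f), (x,), (torch.ones_like(x),))
    return hvp

X = torch.randn(5, 3)  # batch of 5 samples, 3 features each

# vmap runs the per-sample computation over the batch dimension.
diags = vmap(lambda x: hessian_diag_elementwise(f, x))(X)
print(diags.shape)  # torch.Size([5, 3])
```

For this particular f, the result can be checked against the closed-form second derivative -x^2 sin(x) + 4x cos(x) + 2 sin(x), applied coordinate-wise.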

You can do that with functorch (see the functorch 2.0 documentation), which provides a JAX-like API in PyTorch.
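A minimal sketch of the vmap(hessian(fn)) pattern, assuming PyTorch 2.0 or later, where functorch's API is available as torch.func. Here the Hessian is taken with respect to a shared parameter vector `w`, while vmap maps over the batch of inputs `X` and targets `Y`; `sample_loss` and the tensor sizes are hypothetical. The diagonal is extracted afterwards with torch.diagonal, so this still pays for a full D x D Hessian per sample.

```python
import torch
from torch.func import hessian, vmap

torch.manual_seed(0)

B, D = 4, 3              # batch size and number of parameters (illustrative)
w = torch.randn(D)       # parameters shared by all samples
X = torch.randn(B, D)    # one input row per sample
Y = torch.randn(B)       # one target per sample

def sample_loss(w, x, y):
    # Hypothetical per-sample loss: squared error of a single tanh unit.
    pred = torch.tanh(x @ w)
    return (pred - y) ** 2

# hessian() differentiates w.r.t. argument 0 (w) by default.
# in_dims=(None, 0, 0): w is shared, x and y are mapped over dim 0.
per_sample_H = vmap(hessian(sample_loss), in_dims=(None, 0, 0))(w, X, Y)
print(per_sample_H.shape)  # torch.Size([4, 3, 3])

# Keep only the diagonal of each per-sample Hessian: shape (B, D).
per_sample_diag = torch.diagonal(per_sample_H, dim1=-2, dim2=-1)
print(per_sample_diag.shape)  # torch.Size([4, 3])
```

For a full nn.Module, the same idea would apply by differentiating a functional version of the model (for example via torch.func.functional_call) with respect to its parameters, at a memory cost that grows quadratically with the parameter count.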


Thank you, I have found a solution to my question.
There is a repository that provides a function to compute an approximation of the Hessian matrix for every sample in the batch:
https://backpack.pt/