How to compute a second-order gradient

Hi Folks,

Is it feasible to use ∇x∇θL(z, θ) as a regularizer in a classical loss function (e.g. cross-entropy)? If so, how can it be computed efficiently in PyTorch? Any suggestions would be highly appreciated!
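For context, the only plain-autograd route I can think of differentiates the parameter gradient entry by entry, which seems too slow for a large model (a rough sketch with a stand-in model and loss):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                     # stand-in model
x = torch.randn(1, 4, requires_grad=True)   # input we differentiate w.r.t.
loss = model(x).pow(2).sum()                # stand-in scalar loss L(z, theta)

# first-order gradient w.r.t. parameters, kept differentiable
g_theta = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
# one row of the mixed second-order gradient: d/dx of a single entry of
# dL/dtheta; looping over every parameter entry this way is exact but slow
row = torch.autograd.grad(g_theta[0].reshape(-1)[0], x, retain_graph=True)[0]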

Have a look at the functorch extension; it allows for efficient computation of higher-order gradients.

The link to functorch: Install functorch — functorch 1.13 documentation

Thank you @AlphaBetaGamma96, I found a way to compute the 2nd-order gradient w.r.t. the inputs: functorch.grad — functorch 1.13 documentation. I am not sure whether it can be used for my aforementioned case: computing the 2nd-order gradient w.r.t. the model parameters first, then w.r.t. the inputs.
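For reference, a minimal sketch of the input-side 2nd-order gradient with those transforms (the cubic loss here is just a stand-in):

import torch
from functorch import grad, jacrev

def loss_fn(x):
    # stand-in scalar loss of the input alone
    return (x ** 3).sum()

x = torch.randn(3)
# grad gives dL/dx; jacrev of that gives the Hessian d2L/dx2
hess = jacrev(grad(loss_fn))(x)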

You can just wrap the different transforms together like so:

from functorch import make_functional, vmap, jacrev

model = Model()
fnet, params = make_functional(model)

# inner jacrev: d(output)/d(params) (argnums=0); outer jacrev: d/dx of
# that (argnums=1); vmap maps over the batch dimension of the input x
per_sample_gradients = vmap(jacrev(jacrev(fnet, argnums=0), argnums=1),
                            in_dims=(None, 0))(params, x)

functorch is compositional, so you can compose different jacrev (or grad) calls to get exactly the derivative you want.
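If you want the loss-based version from your original question, something along these lines should work (a sketch; the stand-in model, batch, and squared-norm regularizer are just placeholders):

import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional, grad, jacrev, vmap

model = nn.Linear(4, 3)                      # stand-in for the real model
fnet, params = make_functional(model)

def loss_fn(params, x, y):
    # scalar per-sample cross-entropy loss L(z, theta)
    logits = fnet(params, x.unsqueeze(0))
    return F.cross_entropy(logits, y.unsqueeze(0))

x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

# inner grad: dL/dtheta (argnums=0); outer jacrev: d/dx of that
# (argnums=1); vmap maps the whole computation over the batch
mixed = vmap(jacrev(grad(loss_fn, argnums=0), argnums=1),
             in_dims=(None, 0, 0))(params, x, y)

# one possible regularizer: squared norm of the mixed gradient over the
# parameter tree and the batch
reg = sum(m.pow(2).sum() for m in mixed)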

This is exactly what I am looking for. Thanks a lot!