Hi Folks,
Is it feasible to use ∇x∇θL(z, θ) as a regularizer in a classical loss function (e.g. cross-entropy)? If so, how can it be computed efficiently in PyTorch? Any suggestions would be highly appreciated!
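To pin down what such a regularizer would be: with z = (x, y), θ ∈ ℝ^p and x ∈ ℝ^d, ∇x∇θL(z, θ) is a p × d matrix of mixed second derivatives. One possible scalar penalty is its squared Frobenius norm, averaged over N samples and weighted by a hypothetical λ. This choice of norm is an assumption; the question does not fix one.

$$
\left[\nabla_x \nabla_\theta L(z,\theta)\right]_{jk} = \frac{\partial^2 L(z,\theta)}{\partial x_k\,\partial \theta_j},
\qquad
\mathcal{L}_{\text{reg}}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\Big( L(z_i,\theta) + \lambda \left\| \nabla_x \nabla_\theta L(z_i,\theta) \right\|_F^2 \Big)
$$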
Have a look at the functorch extension; it allows for efficient higher-order gradients.
The link to functorch: Install functorch — functorch 1.13 documentation
Thank you @AlphaBetaGamma96, I found a way to compute 2nd-order gradients w.r.t. the inputs with functorch.grad (functorch 1.13 documentation). I'm not sure whether it can be used for the case I described above: taking the gradient w.r.t. the model parameters first and then w.r.t. the inputs.
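For what it's worth, `grad` takes an `argnums` argument, so nesting it with different `argnums` already gives a mixed derivative for a scalar function. A minimal sketch on a hypothetical toy function f(w, x) = w·x³, written against `torch.func` (the PyTorch ≥ 2.0 successor of functorch; in functorch 1.13 the same `grad` is importable from `functorch`):

```python
import torch
from torch.func import grad  # functorch.grad in functorch 1.13

def f(w, x):
    # Hypothetical scalar "loss" of one parameter w and one input x.
    return w * x ** 3

w = torch.tensor(1.5)
x = torch.tensor(2.0)

# Second derivative w.r.t. the input only: d²f/dx² = 6·w·x
d2f_dx2 = grad(grad(f, argnums=1), argnums=1)(w, x)

# Mixed derivative: first d/dw (argnums=0), then d/dx (argnums=1): d/dx(x³) = 3·x²
d2f_dxdw = grad(grad(f, argnums=0), argnums=1)(w, x)

print(d2f_dx2.item(), d2f_dxdw.item())
```

With w = 1.5 and x = 2 the analytic values are 18 and 12. `grad` only handles scalar outputs, so for a vector of parameters the inner call has to become `jacrev`.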
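You can just nest the different transforms like so:

```python
from functorch import make_functional, vmap, jacrev

model = Model()  # any nn.Module; x is a batch of inputs
fnet, params = make_functional(model)  # fnet(params, x) calls the model functionally
# inner jacrev: w.r.t. params (argnums=0); outer jacrev: w.r.t. the input (argnums=1); vmap maps over x's batch dim
per_sample_gradients = vmap(jacrev(jacrev(fnet, argnums=0), argnums=1), in_dims=(None, 0))(params, x)
```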
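The snippet above leaves `Model` and `x` undefined and differentiates the raw model output. Below is a self-contained sketch of the same composition applied to a per-sample cross-entropy loss. Some assumptions: it uses `torch.func` (PyTorch ≥ 2.0), where `functional_call` replaces functorch's `make_functional`; the `nn.Linear(4, 3)` model, the helper name `sample_loss`, and the random data are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)

# Hypothetical tiny classifier: 4 input features, 3 classes.
model = nn.Linear(4, 3)
# Detached copies: here we only want the values of the mixed derivative.
params = {name: p.detach() for name, p in model.named_parameters()}

def sample_loss(params, x, y):
    # One unbatched sample: x has shape (4,), y is a 0-dim class index.
    logits = functional_call(model, params, (x.unsqueeze(0),))
    return F.cross_entropy(logits, y.unsqueeze(0))

# Inner jacrev: dL/dθ (argnums=0); outer jacrev: d/dx of that (argnums=1).
mixed = jacrev(jacrev(sample_loss, argnums=0), argnums=1)
# vmap over the batch; params are shared across samples (in_dims=None).
per_sample = vmap(mixed, in_dims=(None, 0, 0))

x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
out = per_sample(params, x, y)

# Each entry has shape (batch, *param.shape, *sample_x.shape).
print(out["weight"].shape)  # torch.Size([8, 3, 4, 4])
print(out["bias"].shape)    # torch.Size([8, 3, 4])
```

The result is a dict keyed like `params`, so entry `out["bias"][i, j, k]` is ∂²L/∂x_k∂b_j for sample i.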
Functorch is compositional, so you can just compose different jacrev calls to get what you want.
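Since these transforms also compose with ordinary autograd, the mixed derivative can be folded into the training loss and backpropagated like any other term. A sketch under assumptions: the squared-Frobenius-norm penalty, the weight `lam`, the small `nn.Sequential` model, the helper names, and the random data are all hypothetical choices, and it uses `torch.func` (PyTorch ≥ 2.0) instead of functorch 1.13's `make_functional`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))  # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lam = 1e-2  # hypothetical regularization strength

def sample_loss(params, x, y):
    # Cross-entropy for one unbatched sample (x: (4,), y: 0-dim label).
    logits = functional_call(model, params, (x.unsqueeze(0),))
    return F.cross_entropy(logits, y.unsqueeze(0))

# Per-sample ∇x∇θL: jacrev w.r.t. params, then w.r.t. x, vmapped over the batch.
mixed = vmap(jacrev(jacrev(sample_loss, argnums=0), argnums=1), in_dims=(None, 0, 0))

def regularized_loss(x, y):
    params = dict(model.named_parameters())  # not detached: the penalty is differentiated w.r.t. θ
    ce = F.cross_entropy(model(x), y)
    jac = mixed(params, x, y)
    # Squared Frobenius norm of ∇x∇θL per sample (summed over all parameter tensors), batch mean.
    penalty = sum(j.pow(2).flatten(1).sum(1) for j in jac.values()).mean()
    return ce + lam * penalty, ce, penalty

x = torch.randn(16, 4)
y = torch.randint(0, 3, (16,))
for step in range(5):
    optimizer.zero_grad()
    loss, ce, penalty = regularized_loss(x, y)
    loss.backward()  # backpropagates through the nested jacrev (third-order derivatives)
    optimizer.step()
    print(step, ce.item(), penalty.item())
```

Note the cost: each sample's mixed Jacobian has (number of parameters) × (input size) entries, and the backward pass differentiates through it again, so this direct form is only practical for small models and inputs.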
This is exactly what I am looking for. Thanks a lot!