Hi @LuoXin-s,
You need to compose jacrev with jacrev to create a 2nd-order derivative within the torch.func framework. Within the torch.func namespace, gradients are computed in a functional manner (a more direct mathematical approach).
You can use torch.func.hessian directly:
import torch
from torch.func import hessian, vmap, functional_call

net = Model(*args, **kwargs)              # our model
x = torch.randn(batch_size, num_inputs)   # inputs
params = dict(net.named_parameters())     # params (need to be an input for torch.func to work)

# functionalized version (params are now an explicit input to the model)
def fcall(params, x):
    return functional_call(net, params, x)

def compute_hessian(params, x):
    # forward-over-reverse Hessian of the output w.r.t. x (argnums=1)
    _hessian_wrt_x = hessian(fcall, argnums=1)(params, x)
    return _hessian_wrt_x

# Compute the Hessian for a single sample, then vectorize over the batch with vmap.
hessians = vmap(compute_hessian, in_dims=(None, 0))(params, x)
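If what you actually want is the Laplacian, a minimal sketch (assuming the model returns a scalar per sample, so each per-sample Hessian has shape (num_inputs, num_inputs)) is to take the trace of each Hessian:

# Laplacian = trace of the per-sample Hessian (sum over its diagonal).
laplacian = torch.diagonal(hessians, dim1=-2, dim2=-1).sum(-1)   # shape: (batch_size,)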
Or, you can compose jacrev with itself, which you can view as the equivalent of the create_graph arg within the torch.autograd API (albeit with the graph constructed directly).
from torch.func import jacrev

def compute_hessian(params, x):
    # reverse-over-reverse Hessian of the output w.r.t. x (argnums=1)
    _hessian_wrt_x = jacrev(jacrev(fcall, argnums=1), argnums=1)(params, x)
    return _hessian_wrt_x
And then use torch.func.vmap to vectorize over a set of samples, as in the first example above.
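For example, reusing params, x, and the vmap import from the first snippet:

# Vectorize the reverse-over-reverse Hessian computation over the batch dim of x.
hessians = vmap(compute_hessian, in_dims=(None, 0))(params, x)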
Here are some other examples on the forums: