Hi @rfmiotto,
You can do this quite efficiently with the `torch.func`
namespace, which requires PyTorch 2.1+. Here's an example that computes the Laplacian with respect to an input `x`:
import torch
from torch.func import hessian, vmap, functional_call
# The model to differentiate (Model, args and kwargs are placeholders for your own).
net = Model(*args, **kwargs)
# A batch of random inputs, one row per sample.
x = torch.randn(batch_size, num_inputs)
# torch.func transforms operate on pure functions, so the parameters have to be
# passed in explicitly; collect them by name.
params = {name: tensor for name, tensor in net.named_parameters()}
# Functional version of the model: the parameters become an explicit input,
# which torch.func transforms (hessian, vmap, ...) require.
def fcall(parms, x):
    """Run ``net`` on ``x`` using the parameters given in ``parms``.

    Args:
        parms: Mapping from parameter name (as produced by
            ``net.named_parameters()``) to tensor. These replace the module's
            own parameters for this call.
        x: Input passed to ``net``.

    Returns:
        The output of ``net`` evaluated with ``parms``.
    """
    # Use the argument, not a global ``params``: otherwise any transform applied
    # over the parameters would never see the values it passes in.
    return functional_call(net, parms, x)
def compute_laplacian(params, x):
    """Return the Laplacian of the model output with respect to one input sample.

    The Laplacian is the trace of the Hessian: sum_i d^2 f / d x_i^2. The
    function handles a SINGLE sample. Vectorize it over a batch with
    ``vmap(compute_laplacian, in_dims=(None, 0))``, which hides the batch
    dimension from this function.

    Args:
        params: Parameter dict forwarded to ``fcall``; it is not differentiated.
        x: A single input sample. It is assumed to be 1-D (``num_inputs``,); TODO
            confirm this for models that take multi-dimensional inputs.

    Returns:
        The Laplacian, shaped like the model output: a scalar for a scalar
        output, or one value per output component.

    Note:
        The full Hessian is materialized, so memory grows quadratically with
        the number of inputs.
    """
    # torch.func.hessian is jacfwd(jacrev(f)), i.e. forward-over-reverse,
    # taken w.r.t. argument 1 (x). Result shape: output.shape + x.shape + x.shape.
    hessian_wrt_x = hessian(fcall, argnums=1)(params, x)
    # Take the diagonal over the last two (input) dims. Negative indices keep
    # this valid for any output rank. Summing the diagonal gives the trace,
    # which is the Laplacian.
    return hessian_wrt_x.diagonal(0, -2, -1).sum(-1)
# Build the per-sample Laplacian, then vectorize it over the batch. The
# parameters are shared by every sample (in_dims None); x is split along its
# first dimension (in_dims 0).
batched_laplacian = vmap(compute_laplacian, in_dims=(None, 0))
laplacian = batched_laplacian(params, x)