Implementing calculation of the Laplacian

Hi @rfmiotto,

You can do this quite efficiently with the torch.func namespace, which requires PyTorch 2.1+. Here's an example that computes the Laplacian with respect to an input x:

import torch
from torch.func import hessian, vmap, functional_call

net = Model(*args, **kwargs) #our model

x = torch.randn(batch_size, num_inputs) #inputs
params = dict(net.named_parameters()) #params (need to be an input for torch.func to work)

#functionalized version (params are now an explicit input to the model)
def fcall(params, x):
  return functional_call(net, params, x)

def compute_laplacian(params, x):
  _hessian_wrt_x = hessian(fcall, argnums=1)(params, x) #forward-over-reverse hessian calc.
  #the Laplacian is the trace of the Hessian: take the diagonal, then sum it
  #use relative dims for vmap (function doesn't see the batch dim of the input)
  _laplacian_wrt_x = _hessian_wrt_x.diagonal(0,-2,-1).sum(-1)
  return _laplacian_wrt_x

#We define the laplacian for a single sample, then vectorize over the batch.
laplacian = vmap(compute_laplacian, in_dims=(None,0))(params, x)
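
To try this end to end, here is a minimal self-contained sketch. The small MLP, the batch size of 5, and the 3 input features are placeholders I made up for illustration (they stand in for your own Model and data), and the output is squeezed so that each sample maps to a scalar. The last part checks the same recipe on f(x) = sum(x**2), whose Laplacian is known to be 2 * num_inputs.

```python
import torch
from torch import nn
from torch.func import functional_call, hessian, vmap

torch.manual_seed(0)

# Hypothetical stand-in for the model: a small MLP with one output per sample.
net = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))
params = dict(net.named_parameters())
x = torch.randn(5, 3)  # assumed sizes: batch of 5 samples, 3 inputs each

def fcall(params, x):
    # Under vmap, x is a single sample of shape (num_inputs,).
    # Squeeze the trailing output dim so the function returns a scalar.
    return functional_call(net, params, (x,)).squeeze(-1)

def compute_laplacian(params, x):
    hess = hessian(fcall, argnums=1)(params, x)  # shape (num_inputs, num_inputs)
    return hess.diagonal(0, -2, -1).sum(-1)      # trace of the Hessian

# Vectorize over the batch dimension of x; params are shared across samples.
laplacian = vmap(compute_laplacian, in_dims=(None, 0))(params, x)  # shape (5,)

# Sanity check on a function with a known Laplacian.
def quadratic(x):
    return (x ** 2).sum()

quad_lap = vmap(lambda xi: hessian(quadratic)(xi).diagonal(0, -2, -1).sum(-1))(x)
```

If your model returns several outputs per sample, drop the squeeze; the relative dims in diagonal(0, -2, -1) still pick out the Hessian of each output, and you get one Laplacian per output.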