Missing argument create_graph in torch.func api

Hi @LuoXin-s,

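To get a 2nd-order derivative within the torch.func framework, you compose derivative transforms, e.g. jacrev with jacrev. In the torch.func namespace gradients are computed in a functional manner: a transform takes a function and returns its derivative function, which is closer to the mathematical definition, so there is no create_graph argument.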

You can use torch.func.hessian directly:

import torch
from torch.func import hessian, vmap, functional_call

net = Model(*args, **kwargs) #our model

x = torch.randn(batch_size, num_inputs) #inputs
params = dict(net.named_parameters()) #params (need to be an input for torch.func to work)

#functionalize version (params are now an input to the model)
def fcall(params, x):
  return functional_call(net, params, x)

def compute_hessian(params, x):
  _hessian_wrt_x = hessian(fcall, argnums=1)(params, x) #forward-over-reverse hessian calc.
  return _hessian_wrt_x

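#We define the hessian for a single sample, then vectorize over the batch.
hessians = vmap(compute_hessian, in_dims=(None,0))(params, x)
#The laplacian is the trace of each per-sample hessian (sum of the diagonal over the last two dims).
laplacian = hessians.diagonal(dim1=-2, dim2=-1).sum(-1)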
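For anyone who wants to run the snippet above end to end, here is a minimal sketch. The small tanh MLP is only a hypothetical stand-in for Model(*args, **kwargs), and the sizes are arbitrary; I also squeeze the output so each sample gives a scalar, which is an assumption about your model.

```python
import torch
from torch.func import functional_call, hessian, vmap

torch.manual_seed(0)
batch_size, num_inputs = 4, 3  # arbitrary sizes for illustration

# Hypothetical stand-in for Model(*args, **kwargs): scalar output per sample
net = torch.nn.Sequential(
    torch.nn.Linear(num_inputs, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)

x = torch.randn(batch_size, num_inputs)
params = dict(net.named_parameters())

def fcall(params, x):
    # x is a single sample of shape (num_inputs,); squeeze to a scalar output
    return functional_call(net, params, (x,)).squeeze(-1)

def compute_hessian(params, x):
    # Hessian of the scalar output w.r.t. the input x (argnums=1)
    return hessian(fcall, argnums=1)(params, x)

# Vectorize over the batch dimension of x; params are shared (in_dims=None)
hessians = vmap(compute_hessian, in_dims=(None, 0))(params, x)

# Laplacian = trace of each per-sample Hessian
laplacian = hessians.diagonal(dim1=-2, dim2=-1).sum(-1)

print(hessians.shape)   # torch.Size([4, 3, 3])
print(laplacian.shape)  # torch.Size([4])
```

If your model returns more than one output per sample, the Hessian gets an extra leading output dimension, and the trace over the last two dims still gives a Laplacian per output.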

Or, you can compose jacrev with itself, which you can view as the equivalent of the create_graph=True arg in the torch.autograd API (although here the second derivative comes from nesting the transforms rather than from a flag):

from torch.func import jacrev

def compute_hessian(params, x):
  _hessian_wrt_x = jacrev(jacrev(fcall, argnums=1), argnums=1)(params, x) #reverse-over-reverse hessian calc.
  return _hessian_wrt_x

And then use torch.func.vmap to vectorize over a set of samples, like in the first example above.
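To make the link to create_graph concrete, here is a rough sketch that compares the jacrev-over-jacrev Hessian with one built from torch.autograd.grad(..., create_graph=True). The tanh MLP and the helper hessian_autograd are hypothetical and only there for illustration, with a scalar output per sample assumed.

```python
import torch
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)

# Hypothetical stand-in model with a scalar output per sample
net = torch.nn.Sequential(
    torch.nn.Linear(3, 8),
    torch.nn.Tanh(),
    torch.nn.Linear(8, 1),
)
params = dict(net.named_parameters())
x = torch.randn(5, 3)

def fcall(params, x):
    # single sample x of shape (3,) -> scalar
    return functional_call(net, params, (x,)).squeeze(-1)

def compute_hessian(params, x):
    # torch.func route: nest jacrev to get the second derivative w.r.t. x
    return jacrev(jacrev(fcall, argnums=1), argnums=1)(params, x)

hess_func = vmap(compute_hessian, in_dims=(None, 0))(params, x)  # (5, 3, 3)

def hessian_autograd(x_single):
    # torch.autograd route (illustrative helper): keep the graph of the
    # first backward pass with create_graph=True, then differentiate
    # each gradient entry again.
    x_single = x_single.detach().requires_grad_(True)
    y = net(x_single).squeeze(-1)
    (g,) = torch.autograd.grad(y, x_single, create_graph=True)
    rows = [
        torch.autograd.grad(g[i], x_single, retain_graph=True)[0]
        for i in range(g.numel())
    ]
    return torch.stack(rows)

hess_autograd = torch.stack([hessian_autograd(xi) for xi in x])

# The two routes should agree up to floating-point error
max_diff = (hess_func - hess_autograd).abs().max().item()
print(max_diff)
```

The torch.autograd version needs a Python loop over gradient entries and over samples, while the torch.func version gets both from the nested transforms plus vmap.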

Here are some other examples on the forums: