The reason it isn’t working with torch.func.vmap is that torch.func.vmap requires the entire computation to stay within its ‘functionalized’ approach, i.e. you can’t mix torch.autograd operations with torch.func when computing higher-order derivatives.
You can look at a previous answer I’ve shared on the forums, which focuses on using torch.func to compute the Hessian, here: Efficient computation of Hessian with respect to network weights using autograd.grad and symmetry of Hessian matrix - #8 by AlphaBetaGamma96