Hi, I want to replace torch.autograd.functional.jvp/torch.autograd.functional.vjp with torch.func.jvp/torch.func.vjp, since torch.func.jvp is more efficient than the autograd ops.
However, I notice these APIs are not equivalent: I compute a JVP to build a loss, so a second-order gradient is required and I have to set create_graph=True. There is no create_graph argument in the torch.func API, so is torch.func not intended for my use case, or does it support higher-order gradients by default?
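Roughly, the two calls I am comparing look like this (f, x, and v below are just placeholders, not my real model or inputs):
import torch

f = lambda x: x.sin()                   # placeholder for my actual function
x = torch.randn(3, requires_grad=True)  # placeholder input
v = torch.randn(3)                      # tangent vector

# current approach: keep the graph so the JVP itself can be differentiated again
_, jvp_old = torch.autograd.functional.jvp(f, x, v, create_graph=True)

# torch.func version: no create_graph argument
_, jvp_new = torch.func.jvp(f, (x,), (v,))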
You need to compose jacrev with jacrev to compute a second-order derivative within the torch.func framework. Gradients are computed in a functional manner (a more direct mathematical approach) within the torch.func namespace.
You can also use torch.func.hessian directly:
import torch
from torch.func import hessian, vmap, functional_call

net = Model(*args, **kwargs)  # our model
x = torch.randn(batch_size, num_inputs)  # inputs
params = dict(net.named_parameters())  # params (need to be an input for torch.func to work)

# functionalized version (params are now an input to the model)
def fcall(params, x):
    return functional_call(net, params, x)

def compute_hessian(params, x):
    _hessian_wrt_x = hessian(fcall, argnums=1)(params, x)  # forward-over-reverse hessian calc.
    return _hessian_wrt_x

# We define a laplacian func for a single sample, then vectorize over the batch.
laplacian = vmap(compute_hessian, in_dims=(None, 0))(params, x)
Or, you can compose jacrev with itself, which you can view as the equivalent of the create_graph argument in the torch.autograd API (although here the graph is constructed directly).
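For instance, a minimal sketch with a toy scalar function (not your model) would be:
import torch
from torch.func import jacrev

f = lambda x: (x ** 3).sum()  # toy scalar function
x = torch.randn(3)

second_derivative = jacrev(jacrev(f))(x)  # 3x3 matrix of second derivatives (the Hessian of f at x)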
Here, func is my model and input is the model input. The resulting loss serves as part of the overall loss (loss += isometry_loss) for backpropagation, so I have to rely on loss.backward() for backpropagation and parameter updates.
I actually don't want to use the functionalized API, I just want to use the fast JVP implementation.
I aim to incorporate an isometry loss into my model's optimization objective to enhance its isometric properties. Computing the isometry loss requires Jacobian-vector products (JVPs), and it's crucial to set create_graph=True so that gradients can backpropagate through the JVP calculation.
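For completeness, a stripped-down sketch of what I am doing now (net, x, v, and the isometry term below are simplified placeholders, not my actual model or loss):
import torch
import torch.nn as nn

net = nn.Linear(4, 4)              # stand-in for my model
x = torch.randn(8, 4)              # model inputs
v = torch.randn(8, 4)              # tangent vectors

# JVP through the model; create_graph=True keeps the graph so the
# isometry term can be backpropagated into the parameters
_, Jv = torch.autograd.functional.jvp(net, x, v, create_graph=True)

# placeholder isometry term: penalize any change in norm under the map
isometry_loss = ((Jv.norm(dim=-1) - v.norm(dim=-1)) ** 2).mean()

loss = net(x).pow(2).mean()        # stand-in for the main objective
loss += isometry_loss
loss.backward()                    # gradients flow through the JVP term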