Missing create_graph argument in the torch.func API

Hi, I want to replace torch.autograd.functional.jvp/torch.autograd.functional.vjp with torch.func.jvp/torch.func.vjp, since torch.func.jvp is more efficient than the autograd ops.

However, I noticed these APIs are not equivalent. I compute a JVP to build a loss, so a second-order gradient is required and I have to set create_graph=True. There is no create_graph argument in the torch.func API, so is torch.func not intended for my use case, or does it support higher-order gradients by default?

Hi @LuoXin-s,

You need to compose jacrev with jacrev to compute a second-order derivative within the torch.func framework. The gradients are computed in a functional manner (a more direct mathematical approach) within the torch.func namespace.

You can use torch.func.hessian directly:

import torch
from torch.func import hessian, jacrev, vmap, functional_call

net = Model(*args, **kwargs) #our model

x = torch.randn(batch_size, num_inputs) #inputs
params = dict(net.named_parameters()) #params (need to be an input for torch.func to work)

#functionalized version (params are now an input to the model)
def fcall(params, x):
  return functional_call(net, params, x)

def compute_hessian(params, x):
  _hessian_wrt_x = hessian(fcall, argnums=1)(params, x) #forward-over-reverse Hessian calc.
  return _hessian_wrt_x

#We compute the Hessian for a single sample, then use vmap to vectorize over the batch.
hessians = vmap(compute_hessian, in_dims=(None, 0))(params, x)

Or you can compose jacrev with itself, which you can view as the equivalent of the create_graph argument within the torch.autograd API (albeit with the graph constructed directly).

def compute_hessian(params, x):
  _hessian_wrt_x = jacrev(jacrev(fcall, argnums=1), argnums=1)(params, x) #reverse-over-reverse Hessian calc.
  return _hessian_wrt_x

And then use torch.func.vmap to vectorize over a set of samples, like in the first example above.
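
Putting the two snippets together, a minimal sketch of the batched reverse-over-reverse version (reusing fcall, params, and x from the first example) might look like this:

def compute_hessian_rr(params, x):
  # reverse-over-reverse Hessian of the output w.r.t. a single sample x
  return jacrev(jacrev(fcall, argnums=1), argnums=1)(params, x)

# params are shared across the batch (in_dims=None); x is batched along dim 0
per_sample_hessians = vmap(compute_hessian_rr, in_dims=(None, 0))(params, x)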


Thank you for your suggestion. However, I’m still confused. Let me explain further.

Actually, I want to calculate the following loss:

def isometry_loss(self, input, func):
    u = torch.randn_like(input, device=input.device)  # random probe vector
    Ju = torch.autograd.functional.jvp(func, input, u, create_graph=True)[1]     # J u
    JTJu = torch.autograd.functional.vjp(func, input, Ju, create_graph=True)[1]  # J^T (J u)

    TrR = torch.sum(Ju.float() ** 2, dim=tuple(range(1, Ju.dim()))).mean()       # Hutchinson estimate of tr(J^T J)
    TrR2 = torch.sum(JTJu.float() ** 2, dim=tuple(range(1, JTJu.dim()))).mean()  # Hutchinson estimate of tr((J^T J)^2)

    isometry_loss = TrR2 / TrR ** 2
    return isometry_loss

Here, func is my model and input is the model input. This loss serves as part of the overall loss (loss += isometry_loss), so I have to rely on loss.backward() for backpropagation and parameter updates.
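
For context, a rough sketch of how this loss enters the training step (task_loss, target, and opt are placeholder names, not my actual code):

opt.zero_grad()
loss = task_loss(func(input), target)
loss = loss + self.isometry_loss(input, func)
loss.backward()  # gradients must flow back through the jvp/vjp calls above
opt.step()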

I actually don’t want to use the functionalized API; I just want the fast JVP implementation.

I don’t know how to convert it into the isometry loss in my case.
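
For reference, a literal translation of the loss above into the torch.func primitives would presumably look like the sketch below (note that torch.func.vjp returns a function rather than taking the cotangent directly):

def isometry_loss_func(self, input, func):
    # sketch only: same math as above, written with torch.func primitives
    u = torch.randn_like(input, device=input.device)
    _, Ju = torch.func.jvp(func, (input,), (u,))   # J u (forward mode)
    _, vjp_fn = torch.func.vjp(func, input)
    (JTJu,) = vjp_fn(Ju)                           # J^T (J u) (reverse mode)

    TrR = torch.sum(Ju.float() ** 2, dim=tuple(range(1, Ju.dim()))).mean()
    TrR2 = torch.sum(JTJu.float() ** 2, dim=tuple(range(1, JTJu.dim()))).mean()
    return TrR2 / TrR ** 2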

I made some simplifications, and I found that the following losses have different gradients with respect to the network parameters:

def iso_loss1():
    # autograd.functional version; create_graph=True makes the JVP differentiable
    Ju = torch.autograd.functional.jvp(func, input, u, create_graph=True)[1]
    TrR = torch.sum(Ju.float() ** 2, dim=tuple(range(1, Ju.dim()))).mean()

    isometry_loss = TrR
    return isometry_loss

def iso_loss2():
    # torch.func version; there is no create_graph argument here
    Ju = torch.func.jvp(func, (input,), (u,))[1]
    TrR = torch.sum(Ju.float() ** 2, dim=tuple(range(1, Ju.dim()))).mean()

    isometry_loss = TrR
    return isometry_loss
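
(A sketch of the comparison; net stands for the module behind func, and it assumes every parameter receives a gradient:)

net.zero_grad()
iso_loss1().backward()
grads1 = [p.grad.detach().clone() for p in net.parameters()]

net.zero_grad()
iso_loss2().backward()
grads2 = [p.grad.detach().clone() for p in net.parameters()]

max_diff = max((g1 - g2).abs().max().item() for g1, g2 in zip(grads1, grads2))
print(max_diff)  # comes out non-zero, i.e. the two losses backpropagate differently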

Why do you keep create_graph=True in iso_loss1?

Can you share what the outputs of the two losses are?

I aim to incorporate the isometry loss into my model’s optimization objective to enhance its isometric properties. Computing the isometry loss requires Jacobian-vector products (JVPs), and it’s crucial to set create_graph=True so that gradients can backpropagate through the JVP calculation.
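
As a minimal, self-contained illustration of what create_graph=True enables (a toy linear model, not my actual one):

import torch

lin = torch.nn.Linear(4, 4)          # toy stand-in for the real model
x = torch.randn(8, 4)
u = torch.randn_like(x)

# create_graph=True keeps the JVP result attached to the autograd graph,
# so a loss built from it can be backpropagated to the parameters
Ju = torch.autograd.functional.jvp(lin, x, u, create_graph=True)[1]
loss = (Ju ** 2).sum()
loss.backward()
print(lin.weight.grad is not None)   # True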