How can I calculate the gradient inside torch.func.jacrev()?

Hi, everyone. Here is a function that calculates the Neural Tangent Kernel (NTK):

def ntk(model, x, compute="full"):
    import torch
    from torch.func import functional_call, vmap, jacrev

    params = dict(model.named_parameters())

    def fnet_single_torch(params, x):
        y = functional_call(model, params, x.unsqueeze(0)).squeeze(0)
        return y

    # Per-sample Jacobians of the model output w.r.t. the parameters
    jac1 = vmap(jacrev(fnet_single_torch), (None, 0))(params, x)
    jac1 = [j.flatten(2) for j in jac1.values()]

    jac2 = vmap(jacrev(fnet_single_torch), (None, 0))(params, x)
    jac2 = [j.flatten(2) for j in jac2.values()]

    # Compute J(x1) @ J(x2).T
    if compute == "full":
        einsum_expr = "Naf,Mbf->NMab"
    elif compute == "trace":
        einsum_expr = "Naf,Maf->NM"
    elif compute == "diagonal":
        einsum_expr = "Naf,Maf->NMa"
    else:
        assert False

    result = torch.stack(
        [torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]
    )
    result = result.sum(0)
    return result.squeeze()

Now, I need to apply a transformation to the network output that involves the gradient of the output with respect to the input. How can I modify this code?

I have tried adding the transformation directly to the fnet_single_torch function, as below, but I noticed that inside jacrev() the incoming x does not require grad:

def output_transform(x, y):
    # dy/dx -- fails under jacrev() because x does not require grad
    return autograd.grad(y, x, grad_outputs=torch.ones_like(y))[0]

def fnet_single_torch(params, x):
    y = functional_call(model, params, x.unsqueeze(0)).squeeze(0)
    y = output_transform(x, y)
    return y

Hi @xuelanghanbao,

If you want a function that computes the derivative of your model output with respect to the inputs within the torch.func namespace, you can use torch.func.grad. Since the input is the second argument of f, you need argnums=1 (also note that torch.func.grad expects f to return a scalar; for a vector-valued output, use torch.func.jacrev with argnums=1 instead),

def f(params, x):
    return functional_call(model, params, x.unsqueeze(0)).squeeze(0)

grad_f = torch.func.grad(f, argnums=1)(params, x)
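As a quick standalone illustration of torch.func.grad (the toy function here is just for the demo),

```python
import torch
from torch.func import grad

def f(x):
    # Scalar-valued function, as torch.func.grad requires
    return (x ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0])
print(grad(f)(x))  # tensor([2., 4., 6.])
```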

If your input has a batch dimension, you should compose the gradient with a vmap; otherwise you'd compute gradients across different samples in your batch (which isn't what you want),

grad_f = torch.func.vmap(torch.func.grad(f, argnums=1), in_dims=(None, 0))(params, x)
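Putting the pieces together, here's a minimal, self-contained sketch of how the original ntk() could be adapted so the "output" being differentiated is the input gradient of the (summed) model output. The toy MLP, seed, and shapes are illustrative assumptions, not part of the original code:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, jacrev, grad

# Hypothetical toy setup; any small model works
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(5, 3)  # batch of 5 inputs
params = dict(model.named_parameters())

def f(params, x):
    # Scalar output for a single sample (required by torch.func.grad)
    return functional_call(model, params, x.unsqueeze(0)).squeeze(0).sum()

# g(params, x) = df/dx: the transformed "output" whose NTK we want
g = grad(f, argnums=1)

# Per-sample Jacobian of g w.r.t. the parameters, vmapped over the batch
jac = vmap(jacrev(g, argnums=0), in_dims=(None, 0))(params, x)
jac = [j.flatten(2) for j in jac.values()]

# Full NTK of the transformed output: contract over parameter entries
K = sum(torch.einsum("Naf,Mbf->NMab", j, j) for j in jac)
print(K.shape)  # torch.Size([5, 5, 3, 3])
```

The key point is that grad and jacrev compose: jacrev(grad(f, argnums=1), argnums=0) differentiates the input gradient with respect to the parameters, which is exactly the per-sample Jacobian the NTK contraction needs.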