Compute Gramian of Jacobian using func's jvp and/or vjp

Hi,

I am trying to compute the Gramian matrix of the Jacobian of a function. For instance, if my function is

def func(x):
  return A @ x

for some matrix A, then its Jacobian is A and the Gramian of its Jacobian is A @ A^T.

Using torch.func.vjp and torch.func.vmap, we can compute the Jacobian as

import torch
from torch.func import vjp, vmap
matrix = torch.randn([4, 5], requires_grad=True)
input = torch.randn([5])
def func(x):
  return matrix @ x
output, vjp_fn = vjp(func, input)
# Pull back each basis vector of the output space to get the rows of the Jacobian, shape [4, 5]
jacobian = vmap(vjp_fn)(torch.eye(matrix.shape[0]))[0]

Now in principle I could then compute jacobian @ jacobian.T, but this feels inefficient.
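
For reference, this is the baseline I would like to avoid; on this toy example it simply reproduces matrix @ matrix.T:

# Materialise the full Jacobian, then multiply: shape [4, 4]
gramian = jacobian @ jacobian.T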

I was then thinking of computing the Gramian using the forward-mode torch.func.jvp. The idea would be to write the Gramian of the Jacobian as J @ I @ J.T (forward products) instead of I @ J @ J.T @ I (backward products). So in principle, what I would like to do is vectorise jvp, feed it torch.eye(matrix.shape[1]) to obtain the Jacobian, and then feed that result to jvp again.

My problem is that jvp (torch.func.jvp — PyTorch 2.2 documentation) does not have the same prototype as vjp: it requires the tangents as input, i.e. it does not return a function that I could call repeatedly with different tangents. This means the forward pass is recomputed on every call, so it will be less efficient, for instance when I put it in vmap.
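
To make the difference concrete, this is roughly what I mean (same func, input and shapes as in the snippet above; tangent is just a placeholder vector):

from torch.func import jvp, vjp
tangent = torch.randn([5])  # any tangent vector with the input's shape
# vjp runs the forward pass once and returns a pullback that can be reused (and vmapped)
output, vjp_fn = vjp(func, input)
# jvp wants the tangent up front and reruns func for every new tangent
output, jvp_out = jvp(func, (input,), (tangent,))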

Here is my implementation using jvp:

import torch
from torch.func import jvp, vmap
matrix = torch.randn([4, 5], requires_grad=True)
input = torch.randn([5])
def func(x):
    return matrix @ x

output = func(input)
# Here we throw away the first element of jvp's result, which is exactly output again:
# the forward pass is recomputed on every call, which is a bit of a waste
jvp_fn = lambda v: jvp(func, (input,), (v,))[1]
jacobian_map = vmap(jvp_fn)
# First pass: J @ I, transposed back to get the Jacobian J with shape [4, 5]
jacobian = jacobian_map(torch.eye(matrix.shape[1])).T
# Second pass: push each row of J through jvp to get J @ J.T with shape [4, 4]
gramian = jacobian_map(jacobian)
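
As a sanity check on this toy example, the result matches the closed form, since the Jacobian of func is just matrix:

assert torch.allclose(gramian, matrix @ matrix.T)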

Any idea is most welcome!