Hi,

I am trying to compute the Gramian matrix of the Jacobian of a function. For instance, if my function is

```
def func(x):
    return A @ x
```

for some matrix `A`, then its Jacobian is `A` and the Gramian of its Jacobian is `A @ A^T`.
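To make this concrete, here is a small numeric check of that claim (my own sanity-check snippet, using `torch.autograd.functional.jacobian` just for verification):

```
import torch

# For a linear map x ↦ A @ x, the Jacobian is A itself
# and the Gramian of the Jacobian is A @ A.T
A = torch.randn(4, 5)

def func(x):
    return A @ x

x = torch.randn(5)
jac = torch.autograd.functional.jacobian(func, x)
assert torch.allclose(jac, A)
assert torch.allclose(jac @ jac.T, A @ A.T)
```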

Using `torch.func.vjp` and `torch.func.vmap`, we can compute the Jacobian as

```
import torch
from torch.func import vjp, vmap

matrix = torch.randn([4, 5], requires_grad=True)
input = torch.randn([5])

def func(x):
    return matrix @ x

output, vjp_fn = vjp(func, input)
jacobian = vmap(vjp_fn)(torch.eye(matrix.shape[0]))[0]
```

Now, in principle, I could then compute `jacobian @ jacobian.T`, but this feels inefficient.
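Spelled out (repeating the vjp snippet above so it runs on its own), this baseline would be:

```
import torch
from torch.func import vjp, vmap

matrix = torch.randn([4, 5])
input = torch.randn([5])

def func(x):
    return matrix @ x

output, vjp_fn = vjp(func, input)
# Materialise the full Jacobian row by row via vjps with basis vectors
jacobian = vmap(vjp_fn)(torch.eye(matrix.shape[0]))[0]
# Then form the Gramian explicitly
gramian = jacobian @ jacobian.T
assert torch.allclose(gramian, matrix @ matrix.T)
```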

I was then thinking of computing the Gramian using the forward-mode `torch.func.jvp`. The idea would be that the Gramian of the Jacobian is `J @ I @ J.T` in forward mode, instead of `I @ J @ J.T @ I` in backward mode. So, in principle, what I would like to do is vectorise jvp, give it `torch.eye(matrix.shape[1])`, and provide the result to jvp again.

My problem is that `jvp` (torch.func.jvp — PyTorch 2.2 documentation) does not seem to have the same prototype as `vjp`: it requires the tangents as input, i.e. it does not return a function for computing the products. This means that it will be less efficient, for instance if I put it in `vmap`, since the primal computation is repeated for every tangent.
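One API that seems closer to what I want is `torch.func.linearize`, which, if I read the docs correctly, evaluates the function once and returns a jvp function with a `vjp`-like interface (a sketch; I have not profiled whether it actually avoids the repeated primal computation):

```
import torch
from torch.func import linearize, vmap

matrix = torch.randn([4, 5])
input = torch.randn([5])

def func(x):
    return matrix @ x

# linearize evaluates func at `input` once and returns a function
# that only computes Jacobian-vector products at that point
output, jvp_fn = linearize(func, input)
jacobian = vmap(jvp_fn)(torch.eye(matrix.shape[1])).T  # shape [4, 5]
gramian = vmap(jvp_fn)(jacobian)                       # J @ J.T, shape [4, 4]
assert torch.allclose(gramian, matrix @ matrix.T)
```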

Here is my implementation using `jvp`:

```
import torch
from torch.func import jvp, vmap

matrix = torch.randn([4, 5], requires_grad=True)
input = torch.randn([5])

def func(x):
    return matrix @ x

output = func(input)
jvp_fn = lambda v: jvp(func, (input,), (v,))[1]
# Here we throw away the first output, which is exactly `output`; this is a bit of a waste
jacobian_map = vmap(jvp_fn)
jacobian = jacobian_map(torch.eye(matrix.shape[1])).T
gramian = jacobian_map(jacobian)
```
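For this linear example the result can be checked against the closed form (same code as above, with the assertion appended):

```
import torch
from torch.func import jvp, vmap

matrix = torch.randn([4, 5])
input = torch.randn([5])

def func(x):
    return matrix @ x

jvp_fn = lambda v: jvp(func, (input,), (v,))[1]
jacobian_map = vmap(jvp_fn)

# Columns of J, stacked as rows, then transposed back: shape [4, 5]
jacobian = jacobian_map(torch.eye(matrix.shape[1])).T
# Applying J once more to the rows of J gives the (symmetric) Gramian J @ J.T
gramian = jacobian_map(jacobian)

# For func(x) = matrix @ x, the Gramian of the Jacobian is matrix @ matrix.T
assert torch.allclose(gramian, matrix @ matrix.T)
```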

Any idea is most welcome!