Calculate batched jvp (R-op)

One user posted his code for an R-op here, but without any explanation of how to use it. Using the approach described in this blog, I can calculate a Jacobian-vector product (JVP) for a single vector. What I’d like to do is calculate it for an entire batch of vectors (i.e., a matrix). Here is the code I’m using for the single-vector case:

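import torch

mlp = MLP()  # the poster's own model class; n_params() returns its total parameter count
batch_size = 1
x = torch.randn(batch_size, 25)
v = torch.randn(mlp.n_params())  # one flat tangent vector, same shape as vjp_flat
y = mlp(x)
# Dummy cotangent: its values don't matter, but it must require grad
u = torch.ones_like(y, requires_grad=True)

# First backward: J^T u, kept differentiable with respect to u
vjp = torch.autograd.grad(y, mlp.parameters(), grad_outputs=u, create_graph=True)
vjp_flat = torch.cat([grad.reshape(-1) for grad in vjp])
# Second backward: d(v . J^T u)/du = J v, shaped like y
jvp = torch.autograd.grad(vjp_flat, u, grad_outputs=v)[0]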
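A self-contained version of the single-vector code above, as a minimal sketch: the `nn.Sequential` model is a hypothetical stand-in for the poster's `MLP`, and `double_backward_jvp` is a name introduced here. The trick works because the first backward returns J^T u, which is linear in the dummy u, so differentiating v · (J^T u) with respect to u gives J v whatever values u holds.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-in for the poster's MLP (25 inputs, 3 outputs)
model = nn.Sequential(nn.Linear(25, 16), nn.Tanh(), nn.Linear(16, 3))
params = list(model.parameters())
n_params = sum(p.numel() for p in params)

def double_backward_jvp(model, x, v):
    """Return J v, where J = d model(x) / d params and v is a flat vector."""
    params = list(model.parameters())
    y = model(x)
    # Dummy cotangent; J v does not depend on its values
    u = torch.ones_like(y, requires_grad=True)
    # First backward: J^T u, kept differentiable with respect to u
    vjp = torch.autograd.grad(y, params, grad_outputs=u, create_graph=True)
    vjp_flat = torch.cat([g.reshape(-1) for g in vjp])
    # Second backward: d(v . J^T u) / du = J v, same shape as y
    (jv,) = torch.autograd.grad(vjp_flat, u, grad_outputs=v)
    return jv

x = torch.randn(4, 25)
v = torch.randn(n_params)
jv = double_backward_jvp(model, x, v)
print(jv.shape)  # torch.Size([4, 3])
```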

I’ve tried making u and v matrices (with a first dimension larger than 1), and the results are incorrect. I’m validating the JVP against a Jacobian obtained by sequentially calling backward on every element of the network output.
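The validation described above can be sketched as follows, assuming a small hypothetical stand-in model and using `torch.autograd.grad` once per output element rather than repeated `.backward()` calls. Once the full output-parameter Jacobian J is available, the correct batched answer for several tangent vectors is simply J @ V, which is the reference any batched JVP should match.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-in model: 25 inputs, 3 outputs
model = nn.Sequential(nn.Linear(25, 16), nn.Tanh(), nn.Linear(16, 3))
x = torch.randn(4, 25)

def output_param_jacobian(model, x):
    """Full Jacobian of the flattened output w.r.t. the flattened parameters,
    built one output element at a time (slow, meant for validation only)."""
    params = list(model.parameters())
    y = model(x).reshape(-1)
    rows = []
    for i in range(y.numel()):
        grads = torch.autograd.grad(y[i], params, retain_graph=True)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    return torch.stack(rows)  # shape (y.numel(), n_params)

J = output_param_jacobian(model, x)
n_params = J.shape[1]
V = torch.randn(n_params, 5)     # five tangent vectors stored as columns
JV = (J @ V).T.reshape(5, 4, 3)  # one JVP per vector, each shaped like model(x)
```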

Does anyone know how I can accomplish this? I need to compute this quantity quite often.
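One way to get a whole batch of JVPs, sketched here with a different technique from the double-backward trick: forward-mode `torch.func.jvp` applied to a function of the flattened parameters, vectorized across tangent vectors with `torch.func.vmap`. This assumes PyTorch 2.0 or later (where `torch.func` is available); the stand-in model and the helper names `f` and `jvp_one` are hypothetical.

```python
import torch
from torch import nn
from torch.func import functional_call, jvp, vmap

torch.manual_seed(0)
# Hypothetical stand-in model: 25 inputs, 3 outputs
model = nn.Sequential(nn.Linear(25, 16), nn.Tanh(), nn.Linear(16, 3))
x = torch.randn(4, 25)

# Flatten the parameters into one vector so tangents match the poster's flat v
names = [n for n, _ in model.named_parameters()]
shapes = [p.shape for _, p in model.named_parameters()]
sizes = [p.numel() for _, p in model.named_parameters()]
flat = torch.cat([p.detach().reshape(-1) for p in model.parameters()])

def f(flat_params):
    # Unflatten and run the model with these parameters instead of its own
    chunks = torch.split(flat_params, sizes)
    p = {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}
    return functional_call(model, p, (x,))

def jvp_one(v):
    # Forward-mode JVP returns (f(flat), J v); keep only J v
    return jvp(f, (flat,), (v,))[1]

V = torch.randn(5, flat.numel())  # five tangent vectors, one per row
JV = vmap(jvp_one)(V)             # shape (5, 4, 3)
```

Each row of V is treated as an independent tangent, so `JV[i]` is J @ V[i] reshaped like `model(x)`; nothing is summed across vectors.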

My code here might help. Since you are dealing with the Jacobian of the outputs with respect to the parameters, you may need to be more careful with dimensions.
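To make the dimension bookkeeping explicit, here is the identity behind the double-backward trick with shapes attached. The symbols (B inputs in the batch, m outputs per input, n parameters, k tangent vectors) are notation introduced here, not taken from the linked code.

```latex
\begin{aligned}
&y = f(x;\theta) \in \mathbb{R}^{B \times m}, \quad \theta \in \mathbb{R}^{n}, \quad
J = \frac{\partial\,\mathrm{vec}(y)}{\partial \theta} \in \mathbb{R}^{Bm \times n},\\
&g(u) = J^\top \mathrm{vec}(u) \in \mathbb{R}^{n}, \qquad
v^\top g(u) = \mathrm{vec}(u)^\top (J v)
\;\Rightarrow\;
\frac{\partial\, v^\top g(u)}{\partial\, \mathrm{vec}(u)} = J v \in \mathbb{R}^{Bm},\\
&V = [v_1, \dots, v_k] \in \mathbb{R}^{n \times k}
\;\Rightarrow\; J V \in \mathbb{R}^{Bm \times k}.
\end{aligned}
```

A single second backward pass consumes exactly one v of length n. The batch dimension B of u yields the per-sample products J_i v for free, but the k columns of V need k passes (or a vectorizing transform such as vmap).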

Thanks, that seems to work pretty well. It seems possible to get a batch of JVPs when there are multiple Jacobians but only one vector; when the vector is batched, however, the result is a batch of sums. Is that correct?

So, supposing that we have 3 vectors, this code will calculate a batch of sums of the 3 products?
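A small experiment that may explain the "batch of sums" observation, assuming the same kind of hypothetical stand-in model and a helper `jvp_double_backward` introduced here. With a batch of inputs, one double-backward call already returns one product per sample (row i is J_i v), but it accepts only one tangent vector. If several vectors end up combined into that single `grad_outputs`, linearity of J means the result is the sum of the individual products.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-in model: 25 inputs, 3 outputs
model = nn.Sequential(nn.Linear(25, 16), nn.Tanh(), nn.Linear(16, 3))
params = list(model.parameters())
n_params = sum(p.numel() for p in params)
x = torch.randn(4, 25)  # batch of 4 inputs -> 4 per-sample Jacobians J_i

def jvp_double_backward(v):
    """Hypothetical helper: J v via the double-backward trick, shaped like model(x)."""
    y = model(x)
    u = torch.ones_like(y, requires_grad=True)
    vjp = torch.autograd.grad(y, params, grad_outputs=u, create_graph=True)
    vjp_flat = torch.cat([g.reshape(-1) for g in vjp])
    return torch.autograd.grad(vjp_flat, u, grad_outputs=v)[0]

# One vector, many Jacobians: row i of the result is J_i v
v1, v2, v3 = torch.randn(3, n_params)
per_sample = jvp_double_backward(v1)  # shape (4, 3)

# A single call sees only one tangent, so combining three vectors into one
# grad_outputs gives, by linearity, the sum of the three products
summed = jvp_double_backward(v1 + v2 + v3)
separate = sum(jvp_double_backward(v) for v in (v1, v2, v3))
```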

I think you are right. In your case, where v is a collection of multiple vectors, you might need to tweak the components of ujp to get the correct dimensions. Alternatively, you can use get_jvp for the weight of each module and then combine the results.
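The per-module suggestion works because J v splits into blocks: J v = Σ_p J_p v_p, where v_p is the slice of v belonging to parameter tensor p. A minimal sketch, with a hypothetical helper `param_block_jvp` standing in for the `get_jvp` in the linked code (whose actual signature isn't shown here) and a hypothetical stand-in model:

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-in model: 25 inputs, 3 outputs
model = nn.Sequential(nn.Linear(25, 16), nn.Tanh(), nn.Linear(16, 3))
x = torch.randn(4, 25)

def param_block_jvp(model, x, p, v_p):
    """Hypothetical helper: J_p v_p, the JVP restricted to one parameter tensor p."""
    y = model(x)
    u = torch.ones_like(y, requires_grad=True)
    # First backward only w.r.t. p: J_p^T u, still differentiable in u
    (g,) = torch.autograd.grad(y, p, grad_outputs=u, create_graph=True)
    # Second backward: d(v_p . J_p^T u) / du = J_p v_p
    return torch.autograd.grad(g, u, grad_outputs=v_p)[0]

# Split one flat tangent vector into pieces shaped like each parameter tensor
params = list(model.parameters())
v = torch.randn(sum(p.numel() for p in params))
pieces = torch.split(v, [p.numel() for p in params])
v_parts = [c.reshape(p.shape) for c, p in zip(pieces, params)]

# J v is the sum of the per-parameter contributions
jv = sum(param_block_jvp(model, x, p, vp) for p, vp in zip(params, v_parts))
```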