How to calculate Jacobians for a batch

Hello, I want to calculate Jacobian matrices for a batch of data.

I have x of shape (batch_size, 3) and a computed y of shape (batch_size, 3). I need a batch of Jacobian matrices of shape (batch_size, 3, 3).

I tried the following code:

import torch

x = torch.randn(1024, 3, requires_grad=True) # a batch of coordinates
y = torch.sin(x) # calculated y in forward pass

grads = []
for i_dim in range(y.shape[1]):
    y_i = y[:, i_dim]
    ones = torch.ones_like(y_i)
    # `torch.autograd.grad` seems to implicitly sum over the elements of the output,
    # so I calculate the grad separately for every dim of `y`.
    # But this still fails.
    grad = torch.autograd.grad(y_i, x, grad_outputs=ones, create_graph=True, retain_graph=True, is_grads_batched=True)[0]
    grads.append(grad)

grad = torch.stack(grads, dim=1)  # intended shape: (batch_size, 3, 3)

print(grad.shape)

Running this raises:
RuntimeError: If `is_grads_batched=True`, 
we interpret the first dimension of each grad_output as the batch dimension. 
The sizes of the remaining dimensions are expected to match the shape of corresponding output, 
but a mismatch was detected: grad_output[0] has a shape of torch.Size([]) and output[0] has a shape of torch.Size([1024]). 
If you only want some tensors in `grad_output` to be considered batched, consider using vmap.

I think outputs and grad_outputs already have the same shape. Why is this error raised?

BTW, I tried torch.func.jacrev(), but I need to calculate y beforehand, not inside the closure passed to jacrev(), because my code is already structured that way.

Hi Windows!

I don’t think that grad_output works quite the way you are expecting.
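To make that concrete, here is a minimal sketch of the loop from the question with the flag removed (the helper name loop_jacobian is just for illustration). With is_grads_batched = True, autograd treats the first dimension of each grad_outputs tensor as a batch of separate vectors, so your (1024,) tensor of ones is read as 1024 vectors of shape (), which cannot match y_i's shape (1024,). Without the flag, grad_outputs = ones requests a single vector-jacobian product, i.e., the gradient of y_i.sum(); because each y[b] depends only on x[b] here, that sum still yields row i_dim of every per-sample jacobian. This shortcut is only valid when the samples are independent.

```python
import torch

def loop_jacobian(x, y):
    """Per-sample jacobians J[b, i, k] = d y[b, i] / d x[b, k].

    Assumes y[b] depends on x[b] alone (independent samples).
    """
    rows = []
    for i_dim in range(y.shape[1]):
        y_i = y[:, i_dim]                      # shape (batch_size,)
        ones = torch.ones_like(y_i)            # same shape as y_i, no extra batch dim
        # no is_grads_batched: one vector-jacobian product, i.e. grad of y_i.sum()
        row = torch.autograd.grad(y_i, x, grad_outputs=ones,
                                  create_graph=True, retain_graph=True)[0]
        rows.append(row)                       # shape (batch_size, 3)
    return torch.stack(rows, dim=1)            # shape (batch_size, 3, 3)

x = torch.randn(1024, 3, requires_grad=True)   # a batch of coordinates
y = torch.sin(x)                               # calculated y in forward pass
jac = loop_jacobian(x, y)
print(jac.shape)                               # torch.Size([1024, 3, 3])
```

The create_graph = True from the original call is kept so that jac can itself be differentiated later; drop it if you don't need higher-order derivatives.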

Here is a script, loosely based on the code you posted, that I think does what you want
and also illustrates how you can use vmap / jacrev to compute a batch of jacobians:

import torch
print (torch.__version__)

_ = torch.manual_seed (2025)

x = torch.randn (1024, 3, requires_grad = True)   # a batch of coordinates

# we use the fact that y[j] is independent of x[i] unless j == i
# (but we don't use that y[j, l] is independent of x[i, k] unless also l == k)
y = torch.sin (x)                                 # calculated y in forward pass

ysum = y.sum (dim = 0)                            # use independence of y[j] on x[i]

# is_grads_batched is batching over ysum's single dimension of length 3
grad_outputs = torch.eye (y.shape[1])
batch_grad = torch.autograd.grad (ysum, x, grad_outputs = grad_outputs, is_grads_batched = True)[0]

batch_grad = batch_grad.transpose (0, 1)          # move batch_size = 1024 to the front

print ('batch_grad.shape =', batch_grad.shape)
m = 77
print ('m =', m, ', batch_grad[m] = ...')         # observe that any batch_grad[m] is diagonal
print (batch_grad[m])


# use vmap and jacrev to compute jacobian (as batch_grad_b)

batch_jacobian = torch.vmap (torch.func.jacrev (torch.sin))
batch_grad_b = batch_jacobian (x)

print ('batch_grad_b.shape =', batch_grad_b.shape)
print ('torch.equal (batch_grad, batch_grad_b) =', torch.equal (batch_grad, batch_grad_b)) 

And here is its output:

2.6.0+cu126
batch_grad.shape = torch.Size([1024, 3, 3])
m = 77 , batch_grad[m] = ...
tensor([[0.4986, 0.0000, 0.0000],
        [0.0000, 0.9996, 0.0000],
        [0.0000, 0.0000, 0.8100]])
batch_grad_b.shape = torch.Size([1024, 3, 3])
torch.equal (batch_grad, batch_grad_b) = True

It’s not clear to me which of the two approaches illustrated will run faster. The first
approach uses a single explicit forward pass and, under the hood, three backward
passes that I think are vmapped. The second approach, however, runs 1024 vmapped
forward passes, so you might expect that it would be significantly slower.
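One way to settle this empirically is to time both approaches on your own hardware. The sketch below reuses the torch.sin toy from the script above; the function names via_autograd, via_vmap_jacrev, and best_time are hypothetical, and the timings it prints depend on your device and PyTorch version, so no expected numbers are shown.

```python
import time
import torch

def via_autograd(x):
    # one forward pass, then 3 batched backward passes (sum trick: samples independent)
    x = x.detach().requires_grad_(True)
    ysum = torch.sin(x).sum(dim=0)
    g = torch.autograd.grad(ysum, x, grad_outputs=torch.eye(x.shape[1]),
                            is_grads_batched=True)[0]
    return g.transpose(0, 1)                    # (batch_size, 3, 3)

def via_vmap_jacrev(x):
    # per-sample jacobian of sin, vectorized over the batch with vmap
    return torch.vmap(torch.func.jacrev(torch.sin))(x.detach())

def best_time(fn, x, repeats=20):
    fn(x)                                       # warm-up call
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(x)
        times.append(time.perf_counter() - start)
    return min(times)

x = torch.randn(1024, 3)
assert torch.allclose(via_autograd(x), via_vmap_jacrev(x))
for fn in (via_autograd, via_vmap_jacrev):
    print(f"{fn.__name__}: {best_time(fn, x) * 1e3:.3f} ms")
```

This times on the CPU; on a GPU you would also want torch.cuda.synchronize() before reading the clock, since CUDA kernels run asynchronously.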

You don’t explain why you need to calculate y beforehand, but PyTorch provides
torch.func.functional_call(), which makes it easy to run a model as a closure.
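As a rough sketch of what that can look like (the small torch.nn.Sequential model is a stand-in for your real model, and per_sample_fn is a hypothetical name): functional_call(module, params, args) runs the module with the parameter tensors you pass in, so the forward pass becomes an ordinary function that jacrev and vmap can transform. The parameters are detached here on the assumption that you only want derivatives with respect to x, and, as with torch.sin, this only gives per-sample jacobians when each sample's output depends on that sample alone.

```python
import torch
from torch.func import functional_call, jacrev, vmap

model = torch.nn.Sequential(
    torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 3))
# explicit parameter dict; detached because we differentiate w.r.t. x only
params = {name: p.detach() for name, p in model.named_parameters()}

def per_sample_fn(x_single):
    # x_single has shape (3,); add and remove a batch dim of 1 around the model call
    return functional_call(model, params, (x_single.unsqueeze(0),)).squeeze(0)

x = torch.randn(1024, 3)
batch_jac = vmap(jacrev(per_sample_fn))(x)     # (1024, 3, 3)
print(batch_jac.shape)                         # torch.Size([1024, 3, 3])
```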

Best.

K. Frank


Thank you very much. Now I understand that when the samples in a batch are independent, summing y over the batch dimension is OK.

Then another question occurs to me: what if the samples in a batch are not independent? My model uses attention, so the samples are not independent. Can I still compute a batch of Jacobians of shape (batch_size, y_dim, x_dim), rather than (batch_size, y_dim, batch_size, x_dim), using torch.autograd.grad()?

I’m not sure how to use functional_call() yet, but I can explain the problem I ran into when trying jacrev().

My real code solves a PINN problem. What I pass to model() is not just a tensor x, but several things, such as another tensor and multiple masks. The model’s output is also not a single tensor.

My code looks like out1, out2 = model(x1, x2, mask1, mask2), after which I compute the Jacobian of out1 with respect to x1. After that, out1 and out2 are still needed for other computations.

Hi Windows!

Yes, but because the “samples aren’t independent,” you can’t sum over the batch
dimension of your output y. Instead, you will have to have autograd.grad() perform one
backward pass per element of y (3072 = 1024 * 3 in the example below), batched using
is_grads_batched = True. Also, technically speaking, you will not be keeping the full
jacobian matrix: autograd computes it in full, and you then select out just the desired
3x3 per-sample submatrices.

Here is an example:

import torch
print (torch.__version__)

_ = torch.manual_seed (2025)

fc = torch.nn.Linear (1024 * 3, 1024 * 3)          # connect all inputs to all outputs

x = torch.randn (1024, 3, requires_grad = True)    # a batch of coordinates

y = fc (x.flatten()).sigmoid().reshape (1024, 3)   # all elements depend on all elements

grad_outputs = torch.eye (1024 * 3).reshape (1024 * 3, 1024, 3)
batch_grad = torch.autograd.grad (y, x, grad_outputs = grad_outputs, is_grads_batched = True)[0]
print ('batch_grad.shape =', batch_grad.shape)

# select out the 3x3 "jacobian" matrices I think you want
batch_grad = batch_grad.reshape (1024, 3, 1024, 3).diagonal (dim1 = 0, dim2 = 2).permute (2, 0, 1)
print ('batch_grad.shape =', batch_grad.shape)

And here is its output:

2.6.0+cu126
batch_grad.shape = torch.Size([3072, 1024, 3])
batch_grad.shape = torch.Size([1024, 3, 3])
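Before scaling this up, it may help to estimate how large that intermediate is. The (3072, 1024, 3) tensor is the full jacobian of all 3 * batch_size outputs with respect to all 3 * batch_size inputs, so its element count grows as (3 * batch_size) ** 2. A back-of-the-envelope helper (hypothetical name, assuming 4-byte float32 entries):

```python
def full_jacobian_bytes(batch_size, y_dim=3, x_dim=3, bytes_per_elem=4):
    """Bytes needed to hold the full (batch_size*y_dim) x (batch_size*x_dim) jacobian."""
    return (batch_size * y_dim) * (batch_size * x_dim) * bytes_per_elem

for n in (1024, 10**6):
    print(n, full_jacobian_bytes(n), "bytes")
```

For batch_size = 1024 this is about 38 MB, but at batch_size = 10^6 it is 3.6 * 10^13 bytes (about 36 TB), and the one-hot grad_outputs tensor built with torch.eye is just as large.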

You can package your call to model as a function (that refers to some “global” variables)
and pass it to jacrev():

# x1 = whatever ...
# x2 = whatever ...
# mask1 = whatever ...
# mask2 = whatever ...

def applyModelToX1 (x_one):
    out1, _ = model (x_one, x2, mask1, mask2)
    return out1

fullJacobian = torch.func.jacrev (applyModelToX1) (x1)   # shape: out1.shape + x1.shape

# snip out submatrices of fullJacobian, as desired
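Here is a small runnable sketch of that pattern. ToyModel is a hypothetical stand-in for your model (it mixes samples with an attention-like softmax, takes an extra tensor and a mask, and returns two outputs). It also uses jacrev's has_aux = True option, which lets the closure hand back out1 and out2 from the same forward call alongside the jacobian; that is my suggestion, not something your code requires.

```python
import torch
from torch.func import jacrev

torch.manual_seed(0)
N, D = 8, 3

class ToyModel(torch.nn.Module):
    # hypothetical stand-in: every output row depends on every input row
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(D, D)

    def forward(self, x1, x2, mask1):
        mixed = torch.softmax(x1 @ x1.T, dim=-1) @ x1   # attention-like mixing of samples
        out1 = torch.tanh(self.lin(mixed) + x2) * mask1
        out2 = out1.sum(dim=-1)
        return out1, out2

model = ToyModel()
x1 = torch.randn(N, D)
x2 = torch.randn(N, D)
mask1 = torch.ones(N, 1)

def apply_model_to_x1(x_one):
    out1, out2 = model(x_one, x2, mask1)
    # first element is differentiated; the second is passed through as auxiliary output
    return out1, (out1, out2)

full_jac, (out1, out2) = jacrev(apply_model_to_x1, has_aux=True)(x1)
# full_jac[n, i, m, k] = d out1[n, i] / d x1[m, k]; keep only the same-sample blocks
per_sample = full_jac.diagonal(dim1=0, dim2=2).permute(2, 0, 1)   # (N, D, D)
print(per_sample.shape)                                            # torch.Size([8, 3, 3])
```

Note that full_jac still has N * D * N * D entries, so this does not avoid the memory growth discussed above.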

Best.

K. Frank


Thanks, Frank. Your code is correct; I tested it in my actual implementation. Unfortunately, in my real code batch_size is very large (about 10^6), and adding this additional dimension makes the real code unworkable.

I wonder whether I could produce two versions of my model output: the original output, used for standard backpropagation, and a second one, with the cross-sample derivative dependencies eliminated or ignored, used only for the Jacobian computation.

However, I’m not sure how to compute that second output.

Is this a viable idea?

Hi Windows!

You don’t say how your real code doesn’t “work.” Does it give the wrong answer? Does
it crash? Do you run into memory problems? Does it take too long?

As an aside, it’s hard to give useful answers when you keep moving the goal posts. It
appears that your real use case is a far cry from the example you gave initially:

x = torch.randn (1024, 3, requires_grad = True)
y = torch.sin (x)

Does every element of out1, say out1[0, 0], depend on more or less every element
of x1? Furthermore, do the derivatives of out1[0, 0] depend on more or less every
element of x1?

Can you perform a single complete forward pass (with your real batch_size) with
autograd turned off, e.g.:

with torch.inference_mode():
    out1, out2 = model (x1, x2, mask1, mask2)

Can you perform a full forward pass with autograd left on?

Can you perform a single backward pass (for a single scalar):

out1, out2 = model (x1, x2, mask1, mask2)
loss = out1[0, 0]
loss.backward()

You could then keep just the piece of the jacobian you need,
j = x1.grad[0].detach().clone(), and then loop over all the elements of out1,
performing a backward pass for each element (pass retain_graph = True to backward()
so that the graph survives for the next element, and reset x1.grad = None between
passes so that gradients don't accumulate). Before doing this, you could set
requires_grad = False for x2, mask1, mask2, and all of the parameters of
model to reduce the work and storage needed by the forward and backward
passes.

If you can get this to work, you could probably get some speed-up with judicious use
of vmap() and / or autograd.grad (..., is_grads_batched = True) so that your
python loop (if you have one at all) doesn’t have to loop over each element of out1
individually.
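A minimal sketch of that batching, assuming y (e.g., out1) has shape (N, Dy) and x (e.g., x1) has shape (N, Dx) with requires_grad = True; the function name and chunk_size are hypothetical. Each chunk of one-hot grad_outputs asks for chunk_size full rows of the (N*Dy) x (N*Dx) jacobian in one batched backward call, and only the entries belonging to the same sample are kept. The grad_outputs and returned rows then take about chunk_size * N * (Dy + Dx) elements, plus whatever buffers the batched backward itself needs, instead of the full jacobian. The total work is still N * Dy backward passes, so at batch_size of about 10^6 this trades memory for a lot of time.

```python
import torch

def per_sample_jacobian_chunked(y, x, chunk_size=256):
    """Return J of shape (N, Dy, Dx) with J[n, i, k] = d y[n, i] / d x[n, k].

    y may depend on *all* rows of x (e.g. via attention); x is a leaf with
    requires_grad=True. Only same-sample entries of the full jacobian are kept.
    """
    N, Dy = y.shape
    Dx = x.shape[1]
    y_flat = y.reshape(-1)                                # (N*Dy,), row-major: idx = n*Dy + i
    out = torch.empty(N, Dy, Dx, dtype=x.dtype, device=x.device)
    for start in range(0, N * Dy, chunk_size):
        stop = min(start + chunk_size, N * Dy)
        idx = torch.arange(start, stop, device=x.device)
        r = torch.arange(stop - start, device=x.device)
        # one-hot grad_outputs: batch entry r selects output element idx[r]
        grad_outputs = torch.zeros(stop - start, N * Dy, dtype=y.dtype, device=y.device)
        grad_outputs[r, idx] = 1.0
        rows = torch.autograd.grad(y_flat, x, grad_outputs=grad_outputs,
                                   retain_graph=True, is_grads_batched=True)[0]  # (chunk, N, Dx)
        n = idx // Dy                                     # sample index of each output element
        i = idx % Dy                                      # component index within that sample
        out[n, i] = rows[r, n]                            # keep d y[n, i] / d x[n, :]
    return out

# usage on a small version of the fully-connected example above
torch.manual_seed(2025)
N = 8
fc = torch.nn.Linear(N * 3, N * 3)                        # connect all inputs to all outputs
x = torch.randn(N, 3, requires_grad=True)
y = fc(x.flatten()).sigmoid().reshape(N, 3)
J = per_sample_jacobian_chunked(y, x, chunk_size=5)
print(J.shape)                                            # torch.Size([8, 3, 3])
```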

Good luck.

K. Frank