I would like to use `torch.autograd.grad` to compute per-sample Jacobian matrices for a batch of inputs.

Say I have a function `y = f(x)` where `y` has the same dimension as `x`. If I have a batch of inputs `x`, denoted x₁, x₂, …, xₙ, where the batch size is `n`, I want to compute the quantity

```
∑ᵢ trace(∂f(xᵢ)/∂xᵢ)
```

Namely, for each d × d Jacobian matrix ∂f(xᵢ)/∂xᵢ (where d is the dimension of xᵢ), I want to compute the trace of this matrix and then sum over the data index `i`.
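
For a single sample xᵢ this is just the trace of an ordinary d × d matrix; as a sanity check, it can be computed directly, for example with `torch.autograd.functional.jacobian`:

```python
import torch

f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
xi = torch.tensor([1., 2.])                        # a single input with d = 2
jac_i = torch.autograd.functional.jacobian(f, xi)  # the d x d Jacobian ∂f(xᵢ)/∂xᵢ
print(jac_i.trace())                               # trace(∂f(xᵢ)/∂xᵢ)
```

My own (loop-based) attempt with `torch.autograd.grad` looks like this: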

```python
import torch

def foo(f: torch.nn.Module, xi: torch.Tensor):
    """
    Compute trace(∂f(xᵢ)/∂xᵢ).
    Here xi is a single input, not a batch of data.
    """
    x_dim = xi.numel()
    xi.requires_grad_(True)
    y = f(xi)
    trace = 0.0
    for j in range(x_dim):
        # one autograd.grad call per output component: row j of the Jacobian,
        # of which we keep only the diagonal entry ∂f(xᵢ)[j]/∂xᵢ[j]
        grad_j = torch.autograd.grad(y[j], xi, retain_graph=True)[0]
        trace = trace + grad_j[j]
    return trace

if __name__ == "__main__":
    x_batch = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
    trace_sum = 0
    for i in range(x_batch.shape[0]):
        xi = x_batch[i].clone()
        trace_sum += foo(f, xi)
    print(trace_sum)
```

Is there a way to avoid looping over the batch dimension with the for loop? Thanks a lot for your kind help in advance!

P.S. I also figured out that the desired quantity ∑ᵢ trace(∂f(xᵢ)/∂xᵢ) is just

```
∑ᵢ ∑ⱼ ∂f(xᵢ)[j]/∂xᵢ[j]
```

So I should be able to do something like this:

```python
import torch

def foo_batch(f: torch.nn.Module, x):
    """
    Compute ∑ᵢ trace(∂f(xᵢ)/∂xᵢ).
    x is a batch of inputs.
    """
    x.requires_grad_(True)
    y = f(x)
    y_flat = y.reshape((-1,))
    x_flat = x.reshape((-1,))
    # This is memory inefficient as it requires a huge identity matrix as grad_outputs.
    eye = torch.eye(y_flat.numel())
    # this call raises the RuntimeError quoted below
    jac = torch.autograd.grad(y_flat, x_flat, grad_outputs=eye, is_grads_batched=True)[0]
    return jac.trace()

if __name__ == "__main__":
    x_batch = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
    print(foo_batch(f, x_batch))
```

But when I run this code, I get the runtime error

```
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior
```

I am not sure what this error means; any help would be appreciated.

Hi Hongkai!

`y_flat` is a function of `x` (as is `y`). However, `y_flat` is not a function of
`x_flat` in the sense of the chain of computations recorded by pytorch's
computation graph. (Because `x_flat` is an invertible function of `x`, we
can understand `y_flat` to be implicitly a function of `x_flat`, but autograd
only knows about the computations it actually recorded.) This is why you get
the "appears to not have been used in the graph" error: because `y_flat` is not
(according to the computation graph) a function of `x_flat`.
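
Here is a small sketch of the same issue in isolation (not your exact code): asking for gradients with respect to a freshly reshaped tensor that the output was not computed from reproduces the error:

```python
import torch

f = torch.nn.Linear(2, 2)
x = torch.randn(3, 2, requires_grad=True)
y = f(x)                    # the graph records x -> y
x_flat = x.reshape((-1,))   # a separate branch x -> x_flat; y was not computed from x_flat
# raises: "One of the differentiated Tensors appears to not have been used in the graph."
torch.autograd.grad(y.sum(), x_flat)
```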

You may tweak `foo_batch()` so that `y_flat` is a function of `x_flat` (and so
that `x_flat` becomes a leaf node of the computation graph).

As an aside, it might be faster or more memory efficient to use `jacrev()` or
`jacobian()` to compute your jacobian matrix.

Consider:

```python
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> def foo_batchB (f: torch.nn.Module, x):
...     x_flat = x.reshape((-1,))
...     x_flat.requires_grad = True   # x_flat will now be the leaf of the computation graph
...     y = f (x_flat.reshape (x.shape))
...     y_flat = y.reshape((-1,))
...     # This is memory inefficient as it requires a huge identity matrix as grad_outputs.
...     eye = torch.eye (y_flat.numel())
...     jac = torch.autograd.grad (y_flat, x_flat, grad_outputs = eye, is_grads_batched = True)[0]
...     return jac.trace()
...
>>> f = torch.nn.Sequential (torch.nn.Linear (2, 3), torch.nn.ReLU(), torch.nn.Linear (3, 2))
>>> x_batch = torch.tensor ([[1, 2.], [3, 4.], [5, 6.]])
>>>
>>> foo_batchB (f, x_batch)
>>>
>>> torch.func.jacrev (f) (x_batch).reshape (x_batch.numel(), x_batch.numel()).trace()
>>>
>>> torch.autograd.functional.jacobian (f, x_batch, vectorize = True).reshape (x_batch.numel(), x_batch.numel()).trace()
tensor(-0.0179)
```

[Edit: Note that using `foo_batchB()`, `jacrev()`, or `jacobian()` with `x_batch`
will compute the “block-off-diagonal” elements of the full jacobian matrix (even
though they are all zero). Your original for-loop version avoids this inefficiency
(although it does compute the full per-batch-element jacobian matrices whose
individual off-diagonal elements get ignored when you compute `.trace()`). So
your for-loop version could well be the most efficient, especially as the batch
size becomes larger. It’s conceivable that you could “vectorize” your for-loop
version using vmap(), but I’ve never used `vmap()`.]
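
As a further aside, one possible way to avoid both kinds of wasted work, exploiting the identity ∑ᵢ ∑ⱼ ∂f(xᵢ)[j]/∂xᵢ[j] from your post, is to loop over the output dimension d instead of the batch dimension, with one backward pass over the whole batch per output component. A sketch (assuming, as here, that f acts on each batch element independently):

```python
import torch

def trace_sum_by_output_dim(f: torch.nn.Module, x_batch: torch.Tensor):
    x = x_batch.clone().requires_grad_(True)
    y = f(x)                                      # shape (n, d)
    total = 0.0
    for j in range(y.shape[1]):
        # d(sum_i f(x_i)[j]) / dx has shape (n, d); keep only its j-th column
        grad_j = torch.autograd.grad(y[:, j].sum(), x, retain_graph=True)[0]
        total = total + grad_j[:, j].sum()
    return total

f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
x_batch = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
print(trace_sum_by_output_dim(f, x_batch))
```

This never forms a full jacobian matrix; it only ever materializes the (n, d) gradient of one output component at a time.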

Best.

K. Frank


Thanks a lot @KFrank for your kind help! The explanation is super clear, much appreciated!

In case it helps future users, when using `vmap` the code looks something like this (here using `torch.func.jacrev` to get the per-sample Jacobian):

```python
def foo(f: torch.nn.Module, xi: torch.Tensor):
    x_dim = xi.numel()
    # per-sample Jacobian, then its trace
    return torch.func.jacrev(f)(xi).reshape(x_dim, x_dim).diagonal().sum()

# f and x_batch as defined above; vmap over the batch, then sum the per-sample traces
trace_sum = torch.func.vmap(lambda xi: foo(f, xi))(x_batch).sum()
```