Use torch.autograd.grad for a batch of inputs

I would like to use torch.autograd.grad to compute Jacobian matrices for a batch of inputs.

Say I have a function y = f(x) where y has the same dimension as x. Given a batch of inputs x₁, x₂, …, xₙ, where the batch size is n, I want to compute the quantity

∑ᵢ trace(∂f(xᵢ)/∂xᵢ)

Namely, for each d × d Jacobian matrix ∂f(xᵢ)/∂xᵢ (where d is the dimension of xᵢ), I want to compute the trace of this matrix and then sum over the data index i.

Currently I use torch.autograd.grad to compute it in a for loop:

import torch


def foo(f: torch.nn.Module, xi: torch.Tensor):
    """
    Compute trace(∂f(xᵢ)/∂xᵢ).
    Here xi is a single input, not a batch of data.
    """
    xi.requires_grad = True
    x_dim = xi.numel()
    jacobian = torch.autograd.grad(
        f(xi), xi, grad_outputs=(torch.eye(x_dim),),
        retain_graph=True, create_graph=True, is_grads_batched=True,
    )[0]
    return jacobian.trace()


if __name__ == "__main__":
    x_batch = torch.tensor([[1, 2.], [3, 4.], [5, 6.]])
    f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
    trace_sum = 0
    for i in range(x_batch.shape[0]):
        xi = x_batch[i].clone()
        trace_sum += foo(f, xi)   # foo() sets requires_grad on xi
    print(trace_sum)

Is there a way to avoid looping over the batch dimension with an explicit for loop? Thanks a lot in advance for your kind help!

P.S. I also figured out that the desired quantity ∑ᵢ trace(∂f(xᵢ)/∂xᵢ) is just

∑ᵢ ∑ⱼ ∂f(xᵢ)[j]/∂xᵢ[j]

i.e. the trace of the Jacobian of the flattened batched output with respect to the flattened batched input (that full Jacobian is block-diagonal across batch elements). So I should be able to do this:

def foo_batch(f: torch.nn.Module, x):
    """
    Compute ∑ᵢ trace(∂f(xᵢ)/∂xᵢ)
    x is a batch of inputs.
    """
    x.requires_grad = True
    y = f(x)
    y_flat = y.reshape((-1,))
    x_flat = x.reshape((-1,))
    # This is memory inefficient as it requires a huge identity matrix as grad_outputs.
    return torch.autograd.grad(y_flat, x_flat, grad_outputs=(torch.eye(x_flat.numel()),), retain_graph=True, create_graph=True, is_grads_batched=True)[0].trace()

if __name__ == "__main__":
    x_batch = torch.tensor([[1, 2.], [3, 4.], [5, 6.]])
    f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
    print(foo_batch(f, x_batch))

But when I run this code, I get the following runtime error:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior

I am not sure what this error means; any help would be appreciated.

Hi Hongkai!

y_flat is a function of x (as is y). However, y_flat is not a function of
x_flat in the sense of the chain of computations recorded by PyTorch's
computation graph. (Because x_flat is an invertible function of x, we
can understand y_flat to be implicitly a function of x_flat, but autograd
doesn't know anything about this.)

The autograd documentation is rather opaque about this, but this error occurs
because y_flat is not (according to the computation graph) a function of
x_flat.
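Here is a stripped-down illustration of the same situation (a minimal sketch, not taken from your code): y is computed from x, so the reshaped view x_flat never appears on the path from x to y, and asking grad() for the derivative with respect to x_flat triggers exactly that error.

import torch

x = torch.randn(3, requires_grad=True)
x_flat = x.reshape(-1)   # downstream of x, but not on the path from x to y
y = (x ** 2).sum()       # y is computed from x, not from x_flat
# Raises: "One of the differentiated Tensors appears to not have been used
# in the graph. Set allow_unused=True if this is the desired behavior."
torch.autograd.grad(y, x_flat)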

You may tweak foo_batch() so that y_flat is a function of x_flat (and so
that x_flat becomes a leaf node of the computation graph).

As an aside, it might be faster or more memory efficient to use jacrev() or
jacobian() to compute your Jacobian matrix.

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> def foo_batchB (f: torch.nn.Module, x):
...     x_flat = x.reshape((-1,))
...     x_flat.requires_grad = True   # x_flat will now be the leaf of the computation graph
...     y = f (x_flat.reshape (x.shape))
...     y_flat = y.reshape((-1,))
...     # This is memory inefficient as it requires a huge identity matrix as grad_outputs.
...     return torch.autograd.grad(y_flat, x_flat, grad_outputs=(torch.eye(x_flat.numel()),), retain_graph=True, create_graph=True, is_grads_batched=True)[0].trace()
...
>>> f = torch.nn.Sequential (torch.nn.Linear (2, 3), torch.nn.ReLU(), torch.nn.Linear (3, 2))
>>> x_batch = torch.tensor ([[1, 2.], [3, 4.], [5, 6.]])
>>>
>>> foo_batchB (f, x_batch)
tensor(-0.0179, grad_fn=<TraceBackward0>)
>>>
>>> torch.func.jacrev (f) (x_batch).reshape (x_batch.numel(), x_batch.numel()).trace()
tensor(-0.0179, grad_fn=<TraceBackward0>)
>>>
>>> torch.autograd.functional.jacobian (f, x_batch, vectorize = True).reshape (x_batch.numel(), x_batch.numel()).trace()
tensor(-0.0179)

[Edit: Note that using foo_batchB(), jacrev, or jacobian() with x_batch
will compute the “block-off-diagonal” elements of the full Jacobian matrix (even
though they are all zero). Your original for-loop version avoids this inefficiency
(although it does compute the full per-batch-element Jacobian matrices whose
individual off-diagonal elements get ignored when you compute .trace()). So
your for-loop version could well be the most efficient, especially as the batch
size becomes larger. It’s conceivable that you could “vectorize” your for-loop
version using vmap(), but I’ve never used vmap().]
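As a quick check of that block structure (a small illustration reusing the f and x_batch defined above), the blocks of the full Jacobian that couple different batch elements come out as all zeros:

# jacrev over the whole batch gives a (3, 2, 3, 2) tensor; flattened to 6 x 6 it
# is block-diagonal, so the cross-batch blocks are identically zero.
J_full = torch.func.jacrev(f)(x_batch).reshape(x_batch.numel(), x_batch.numel())
print(J_full[0:2, 2:4])   # output of batch element 0 w.r.t. input of batch element 1: zeros
print(J_full[0:2, 0:2])   # the 2 x 2 per-sample Jacobian at batch element 0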

Best.

K. Frank


Thanks a lot @KFrank for your kind help! The explanation is super clear, much appreciated!

In case it helps future users, when using vmap the code looks like this:

import torch

f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
x_batch = torch.tensor([[1, 2.], [3, 4.], [5, 6.]])

def foo(xi: torch.Tensor):
    # f is captured from the enclosing scope, so vmap only maps over xi.
    x_dim = xi.numel()
    return torch.func.jacrev(f)(xi).reshape(x_dim, x_dim).trace()

batched_foo = torch.func.vmap(foo)
print(batched_foo(x_batch).sum())

Hi!
I followed your conversation and I have a similar problem: why is this error raised if f maps to a one-dimensional space, and how can I compute the gradient of f with respect to its input (considering a batch input of size [n_batch, 2])?

RuntimeError: If `is_grads_batched=True`, we interpret the first dimension of each grad_output as the batch dimension. The sizes of the remaining dimensions are expected to match the shape of corresponding output, but a mismatch was detected: grad_output[0] has a shape of torch.Size([]) and output[0] has a shape of torch.Size([3]). If you only want some tensors in `grad_output` to be considered batched, consider using vmap.

You can easily replicate it by defining:

f = torch.nn.Sequential (torch.nn.Linear (2, 3), torch.nn.ReLU(), torch.nn.Linear (3, 1))
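In the meantime, the closest I have gotten is something along these lines, following the error message's suggestion to use vmap (just a sketch, I am not sure it is the intended approach):

import torch

f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 1))
x_batch = torch.tensor([[1, 2.], [3, 4.], [5, 6.]])

# torch.func.grad() needs a 0-dim (scalar) output, so squeeze the trailing size-1 dimension.
grad_f = torch.func.grad(lambda xi: f(xi).squeeze(-1))
per_sample_grads = torch.func.vmap(grad_f)(x_batch)  # shape [n_batch, 2] = [3, 2]
print(per_sample_grads)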