Can't vmap autograd.grad over outputs

Hi! I’m trying to compute per-sample gradients efficiently. I know from the tutorial about per-sample gradients that I can vmap a call to torch.func.grad.
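
Concretely, the approach from that tutorial looks roughly like this (a sketch using torch.func.functional_call; the names are just illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(5, 1)
# Detached copies of the parameters, passed explicitly to the functional call
params = {name: p.detach() for name, p in model.named_parameters()}

def compute_loss(params, x, y):
    out = torch.func.functional_call(model, params, (x.unsqueeze(0),))
    return F.mse_loss(out, y.unsqueeze(0))

inputs = torch.randn(64, 5)
targets = torch.randn(64, 1)

# grad differentiates w.r.t. the first argument (params); vmap maps over the batch
# dimension of inputs and targets, giving one per-sample gradient per parameter.
per_sample_grads = torch.vmap(torch.func.grad(compute_loss), in_dims=(None, 0, 0))(
    params, inputs, targets
)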

However, the function returned by torch.func.grad does both a forward and a backward pass. In my specific case, I have to do a forward pass anyway, which creates the autograd graph, so it feels like a waste to not re-use this autograd graph.

I would like to know if there is a way to vmap a call to torch.autograd.grad instead of torch.func.grad, so that the already-existing autograd graph can be reused, and so that no extra forward pass is done.

I tried the following:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

batch_size = 64

inputs = torch.randn((batch_size, 5))
targets = torch.randn((batch_size, 1))

model = nn.Linear(5, 1)
params = list(model.parameters())

outputs = model(inputs)  # shape: [batch_size, 1]
losses = F.mse_loss(outputs, targets, reduction="none").squeeze()  # shape: [batch_size]

def compute_one_gradient(loss: Tensor) -> tuple[Tensor, ...]:
    return torch.autograd.grad(loss, params)

grads = torch.vmap(compute_one_gradient)(losses)

But I end up with the following error:

tests/unit/autogram/test_per_sample_grads.py:65 (test_per_sample_grads)
def test_per_sample_grads():
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        from torch import Tensor
    
        batch_size = 64
    
        inputs = torch.randn((batch_size, 5))
        targets = torch.randn((batch_size, 1))
    
        model = nn.Linear(5, 1)
        params = list(model.parameters())

        outputs = model(inputs)  # shape: [batch_size, 1]
        losses = F.mse_loss(outputs, targets, reduction="none").squeeze()  # shape: [batch_size]
    
        def compute_one_gradient(loss: Tensor) -> tuple[Tensor, ...]:
            return torch.autograd.grad(loss, params)
    
>       grads = torch.vmap(compute_one_gradient)(losses)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

unit/autogram/test_per_sample_grads.py:86: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../.venv/lib/python3.13/site-packages/torch/_functorch/apis.py:202: in wrapped
    return vmap_impl(
../.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py:334: in vmap_impl
    return _flat_vmap(
../.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py:484: in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
unit/autogram/test_per_sample_grads.py:84: in compute_one_gradient
    return torch.autograd.grad(loss, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.venv/lib/python3.13/site-packages/torch/autograd/__init__.py:502: in grad
    result = _engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

t_outputs = (BatchedTensor(lvl=1, bdim=0, value=
    tensor([4.2187e-01, 7.2812e-04, 3.7193e-02, 3.6601e+00, 4.0343e+00, 6.0330e-0...03, 7.4318e-01,
            2.8548e-01, 4.4032e-02, 7.8910e-04, 7.2651e-01],
           grad_fn=<SqueezeBackward0>)
),)
args = ((None,), False, False, (Parameter containing:
tensor([[ 0.1479,  0.4471,  0.2320,  0.2780, -0.1565]], requires_grad=True), Parameter containing:
tensor([0.2146], requires_grad=True)), False)
kwargs = {'accumulate_grad': False}, attach_logging_hooks = False

    def _engine_run_backward(
        t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
        *args: Any,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, ...]:
        attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
        if attach_logging_hooks:
            unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
        try:
>           return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                t_outputs, *args, **kwargs
            )  # Calls into the C++ engine to run the backward pass
E           RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

../.venv/lib/python3.13/site-packages/torch/autograd/graph.py:824: RuntimeError

I think the reason is that when vmap splits the losses tensor to parallelize over it, it does so without creating a grad_fn, so the loss tensor provided to compute_one_gradient does not require grad anymore.

Is there any proper way, or even any trick, to do what I want to do in parallel and without any extra forward pass? (I think torch.func.grad, torch.func.vjp, torch.func.jacrev, etc. all require an extra forward pass.) It almost seems like what I need is the is_grads_batched parameter of torch.autograd.grad, but this is for batched grad_outputs, not for batched outputs.

Thanks in advance for the help!

Would backpack help here?

Thanks for the suggestion! I just tried with backpack’s BatchGrad, but it’s actually slightly slower than my current implementation with the extra forward pass :frowning: It also seems to come with a few bugs that I had to work around, so I would rather not depend on it.

Ah okay. Sorry, but I have never stress-tested backpack, to be honest.

I found this bit

  • is_grads_batched (bool, optional) – If True, the first dimension of each tensor in grad_outputs will be interpreted as the batch dimension. Instead of computing a single vector-Jacobian product, we compute a batch of vector-Jacobian products for each “vector” in the batch. We use the vmap prototype feature as the backend to vectorize calls to the autograd engine so that this computation can be performed in a single call. This should lead to performance improvements when compared to manually looping and performing backward multiple times. Note that due to this feature being experimental, there may be performance cliffs. Please use torch._C._debug_only_display_vmap_fallback_warnings(True) to show any performance warnings and file an issue on github if warnings exist for your use case. Defaults to False.

From: torch.autograd.grad — PyTorch 2.7 documentation

Still trying to fully understand it, but is this something you could potentially use? Did not know of this option till now.

I also tried this, but it really is for batched grad_outputs and not outputs.

For instance, you can do:

torch.autograd.grad(
    outputs,  # One tensor of shape: [10]
    inputs,  # any shapes
    grad_outputs,  # One batch of grad_outputs, of shape: [batch_size, 10]
    is_grads_batched=True
)
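
For completeness, a small runnable sketch of that working case (toy shapes, purely illustrative):

import torch

x = torch.randn(10, requires_grad=True)
outputs = x.sin()             # shape: [10] (not batched)
grad_outputs = torch.eye(10)  # batch of 10 "vectors", shape: [10, 10]

# One call computes 10 vector-Jacobian products; the result has shape [10, 10]
# and is the Jacobian of outputs w.r.t. x.
(jac,) = torch.autograd.grad(outputs, x, grad_outputs, is_grads_batched=True)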

My case is similar to this, except that both outputs and grad_outputs are batched (e.g. the shape of outputs is [batch_size, 10]), which makes the call to torch.autograd.grad fail, even with is_grads_batched=True.

I will have a deeper read over the weekend and try to get back to you with something reasonable, hopefully. Thanks for the responses.

Hi guys, I think I have a few ways to solve this.

The first, and most similar to your original code, is using a vmapped autograd.grad.
In this solution I’m vmapping a call to a function that computes the grads of a single loss w.r.t. the params using autograd.grad.
The vmap runs over an identity matrix whose rows select a single loss at a time when passed as grad_outputs, and the graph is retained, so that once the forward pass is computed we parallelize the backward passes over the same graph without reconstructing it.


import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 64

inputs = torch.randn((batch_size, 5))
targets = torch.randn((batch_size, 1))

model = nn.Linear(5, 1)
params = list(model.parameters())

outputs = model(inputs)  # shape: [batch_size, 1]
losses = F.mse_loss(outputs, targets, reduction="none").squeeze()  # shape: [batch_size]

bs = losses.shape[0]
identity = torch.eye(bs)
def compute_sample_grad(loss_selector):
    grads = torch.autograd.grad(
        outputs=losses,
        inputs=params,
        grad_outputs=loss_selector,
        retain_graph=True
    )
    
    return {
        name: grad for (name, _), grad in zip(model.named_parameters(), grads)
    }
    
per_sample_grads = torch.vmap(compute_sample_grad)(identity)

A second solution would be just using jvp (I believe it runs the forward pass just once, then reuses the graph it computed).

I’ve also found this tutorial that suggests using a functional call: link

Lastly, if you’re running PyTorch 2.7+, I think this wrapper API is what you’re looking for: docs

Hi @valerian.rey, why don’t you use torch.func.grad with the has_aux flag? That way, you can compute the forward and backward in a single call and return the forward output alongside the backward via a single torch.func.grad call (rather than only using torch.func.grad for the backward)?

Hi! Thanks for your suggestions!

First solution

Your first solution is quite good in the general case (in fact it’s very similar to what we use internally in torchjd) but it is also quite inefficient (in terms of memory) when the number of losses grows.

The thing is that internally, this will compute the gradient of each loss with respect to all layer outputs (as these are needed to compute the gradients wrt the model parameters). In particular, it will compute the gradient of loss1 (the loss associated to batch element 1) with respect to output1_2 (the output of the first layer for batch element 2). But if your model operates on each batch element independently (which is the case unless you have BatchNorm layers or other weird things mixing elements of the batch), this gradient will always be 0: the loss for batch element 1 does not depend at all on batch element 2 (or on any of its activations throughout the network).

So while this method works, it has to compute internal jacobian matrices that are very big and sparse, which takes a ton of memory.

What we currently do is to just compute the gradient of the sum of the losses with respect to each layer’s output tensor, and then use it to compute the jacobian of the losses with respect to that layer’s parameters.
For a given layer, the output will have shape [batch_size, …], and so the gradient (call it g) of the sum of losses wrt this output will also have shape [batch_size, …]. Further, the gradient of the i’th loss wrt this output also has shape [batch_size, …], but all of its rows j != i will be 0 and its i’th row will be g[i]. All of this means that g contains all the information we care about: despite being a simple gradient, it has the same information as a jacobian.
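
Here is a small sketch of that observation, with a hypothetical two-layer setup (the names are just for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, hidden = 8, 3
inputs = torch.randn(batch_size, 5)
targets = torch.randn(batch_size, 1)

layer1 = nn.Linear(5, hidden)
layer2 = nn.Linear(hidden, 1)

activations = layer1(inputs)  # shape: [batch_size, hidden]
outputs = layer2(activations)
losses = F.mse_loss(outputs, targets, reduction="none").squeeze()  # shape: [batch_size]

# g: gradient of the *sum* of the losses wrt the layer output, shape [batch_size, hidden].
# Since batch elements don't interact, g[i] is exactly the gradient of losses[i] wrt
# activations[i]; every cross-sample entry of the full jacobian is 0.
(g,) = torch.autograd.grad(losses.sum(), activations, retain_graph=True)

# Sanity check against the full (mostly zero) jacobian of the losses wrt activations,
# which has shape [batch_size, batch_size, hidden].
(jac,) = torch.autograd.grad(
    losses, activations, torch.eye(batch_size), is_grads_batched=True, retain_graph=True
)
idx = torch.arange(batch_size)
assert torch.allclose(jac[idx, idx], g)
assert (jac[idx != idx.unsqueeze(1)] == 0).all()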

The problem arises when we want to use this g as the grad_outputs to compute the jacobian of the losses wrt the layer’s parameters. We can tell torch.autograd.grad that the grad_outputs are batched, using is_grads_batched=True. However, we cannot tell it that the outputs are also batched, so it doesn’t work (btw I opened an issue to add an is_outputs_batched parameter to torch.autograd.grad).

Your first solution is still applicable here, in a slightly modified way: instead of giving a simple identity matrix as grad_outputs, we could give a “diagonalized” version of the gradient wrt the output of the layer: its first element would be the gradient of the first loss wrt the output of the layer (including all of its rows of zeros), its second element would be that of the second loss, and so on. This works, but such a tensor is extremely large, which blows memory usage up.
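
Continuing the hypothetical two-layer sketch above, this modified version would look something like this (again purely illustrative):

# "Diagonalized" version of g: shape [batch_size, batch_size, hidden], with g[i] on the
# diagonal and zeros everywhere else. This is the tensor that blows up with the batch size.
diag_g = torch.zeros(batch_size, batch_size, hidden)
diag_g[torch.arange(batch_size), torch.arange(batch_size)] = g

# One VJP per row of diag_g: per-sample gradients of the losses wrt layer1's parameters,
# each with shape [batch_size, *param.shape].
per_sample_layer1_grads = torch.autograd.grad(
    activations, list(layer1.parameters()), diag_g, is_grads_batched=True, retain_graph=True
)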

Second solution

jvp is not really usable in my case because it does not return a jvp_func: it directly gives the output of the model and the JVP. But (vmapped) vjp is a viable solution to compute both the batched output of a layer and its (batched) vjp_func, which we can store and then use during the backward pass to obtain the jacobians wrt the parameters. See the code of this comment for the implementation. This is not the solution we ended up using because, despite being better in theory (single forward pass), it is a bit slower.

Tutorial about per-sample gradients

Sadly, we can’t directly use what is given in this tutorial: to save memory, we want to compute per-sample gradients on a parameter-by-parameter basis, compute and accumulate some minimally sufficient statistic about those per-sample gradients, and “drop” them (so that they don’t use memory anymore). So we really need a fine-grained solution that doesn’t just give the per-sample gradients for the full model.
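
To make the requirement concrete, the pattern is roughly the following (purely illustrative: per_sample_grads_of is a hypothetical helper, and the per-sample squared gradient norm is just an example of such a statistic):

# Hypothetical: per_sample_grads_of(p) returns the per-sample gradients of the losses
# wrt a single parameter p, with shape [batch_size, *p.shape].
stats = torch.zeros(batch_size)
for p in model.parameters():
    g_p = per_sample_grads_of(p)                          # [batch_size, *p.shape]
    stats += g_p.flatten(start_dim=1).pow(2).sum(dim=1)   # accumulate the statistic...
    del g_p                                               # ...then drop the per-sample grads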

Well, with torch.func.vjp, we can directly obtain both the output and the vjp_func, which is, I think, what you’re suggesting. We can then vmap this torch.func.vjp to do that for a whole batch in parallel.
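
At the loss level (rather than layer by layer), a minimal sketch of that pattern could look like this (a single nn.Linear and illustrative names; not exactly what torchjd does internally):

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(5, 1)
inputs = torch.randn(64, 5)
targets = torch.randn(64, 1)

def one_sample(x, y):
    def loss_fn(weight, bias):
        out = F.linear(x.unsqueeze(0), weight, bias)  # forward pass for this sample
        return F.mse_loss(out, y.unsqueeze(0))

    # A single forward pass that returns both the loss and the function to backpropagate it
    loss, vjp_fn = torch.func.vjp(loss_fn, model.weight.detach(), model.bias.detach())
    grad_weight, grad_bias = vjp_fn(torch.ones_like(loss))
    return loss, grad_weight, grad_bias

losses, weight_grads, bias_grads = torch.vmap(one_sample)(inputs, targets)
# weight_grads: [64, 1, 5], bias_grads: [64, 1]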

In our case, however, we want to be able to do some extra operations after obtaining the per-sample gradients for each parameter, and we want to “drop” them right after that (to save memory). So we would need to apply the forward pass of each layer through a different call to vmap(torch.func.vjp(...)), and store all of the vjp_funcs obtained along the way. This is a viable solution, but it seems a bit slower than not storing anything and just re-doing the call to vmap(torch.func.vjp(...)) each time we need it, which does both an (extra) forward pass and a backward pass.

So yeah, in theory you’re right that using a torch.func function transform to do both the forward and the backward makes sense, and it’s actually doable, but it’s somehow a bit slower :confused:

For reference, the implementation for this is available in this comment.

With is_grads_batched you could do this:

def jacobian_batched(outputs: list[torch.Tensor], wrt: list[torch.Tensor], create_graph=False):
    # Flatten and concatenate all outputs into a single 1-D tensor.
    flat_outputs = torch.cat([o.reshape(-1) for o in outputs])
    # With is_grads_batched=True, each row of the identity matrix is used as grad_outputs
    # for one vector-Jacobian product, all computed in a single call.
    return torch.autograd.grad(
        flat_outputs,
        wrt,
        torch.eye(len(flat_outputs), device=flat_outputs.device, dtype=flat_outputs.dtype),
        create_graph=create_graph,
        is_grads_batched=True,
    )

jacobian_batched([losses], params)

This is a batched VJP. I haven’t done much testing, but for a Hessian it seems similar in performance to torch.func.hessian and torch.autograd.functional.hessian.
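
For what it’s worth, a quick way to sanity-check it against a plain per-sample loop, reusing the losses and params from earlier in the thread:

# Per-sample gradients via a Python loop, for comparison (retain the graph so that
# jacobian_batched can reuse it afterwards).
loop_grads = [torch.autograd.grad(l, params, retain_graph=True) for l in losses]
batched_grads = jacobian_batched([losses], params)  # one tensor of shape [64, *p.shape] per param

for p_idx, batched in enumerate(batched_grads):
    for i in range(len(losses)):
        assert torch.allclose(batched[i], loop_grads[i][p_idx])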