Hi! I’m trying to compute per-sample gradients efficiently. I know from the tutorial about per-sample gradients that I can vmap a call to torch.func.grad.
However, the function returned by torch.func.grad does both a forward and a backward pass. In my specific case, I have to do a forward pass anyway, which creates the autograd graph, so it feels wasteful not to reuse that graph.
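For context, this is roughly the tutorial-style approach I mean (a sketch from memory, using the same model, inputs and targets as in my snippet below; the helper names compute_loss and param_dict are mine). The forward pass is redone inside the grad call for every sample:

from torch.func import functional_call, grad, vmap

# Detached copies of the parameters, passed explicitly so grad can differentiate w.r.t. them.
param_dict = {name: p.detach() for name, p in model.named_parameters()}

def compute_loss(p: dict, x: Tensor, y: Tensor) -> Tensor:
    # Re-runs the forward pass for a single sample (this is the extra forward pass).
    out = functional_call(model, p, (x.unsqueeze(0),))
    return F.mse_loss(out, y.unsqueeze(0))

# One gradient per sample: every entry of the result gets a leading batch dimension.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(param_dict, inputs, targets)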
I would like to know if there is a way to vmap a call to torch.autograd.grad instead of torch.func.grad, so that the already-existing autograd graph is reused and no extra forward pass is done.
I tried the following:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
batch_size = 64
inputs = torch.randn((batch_size, 5))
targets = torch.randn((batch_size, 1))
model = nn.Linear(5, 1)
params = list(model.parameters())
outputs = model(inputs) # shape: [batch_size, 1]
losses = F.mse_loss(outputs, targets, reduction="none").squeeze() # shape: [batch_size]
def compute_one_gradient(loss: Tensor) -> tuple[Tensor, ...]:
    return torch.autograd.grad(loss, params)
grads = torch.vmap(compute_one_gradient)(losses)
But I end up with the following error:
tests/unit/autogram/test_per_sample_grads.py:65 (test_per_sample_grads)
def test_per_sample_grads():
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch import Tensor
    batch_size = 64
    inputs = torch.randn((batch_size, 5))
    targets = torch.randn((batch_size, 1))
    model = nn.Linear(5, 1)
    params = list(model.parameters())
    outputs = model(inputs) # shape: [batch_size, 1]
    losses = F.mse_loss(outputs, targets, reduction="none").squeeze() # shape: [batch_size]
    def compute_one_gradient(loss: Tensor) -> tuple[Tensor, ...]:
        return torch.autograd.grad(loss, params)
>   grads = torch.vmap(compute_one_gradient)(losses)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
unit/autogram/test_per_sample_grads.py:86:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../.venv/lib/python3.13/site-packages/torch/_functorch/apis.py:202: in wrapped
    return vmap_impl(
../.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py:334: in vmap_impl
    return _flat_vmap(
../.venv/lib/python3.13/site-packages/torch/_functorch/vmap.py:484: in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
unit/autogram/test_per_sample_grads.py:84: in compute_one_gradient
    return torch.autograd.grad(loss, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.venv/lib/python3.13/site-packages/torch/autograd/__init__.py:502: in grad
    result = _engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
t_outputs = (BatchedTensor(lvl=1, bdim=0, value=
            tensor([4.2187e-01, 7.2812e-04, 3.7193e-02, 3.6601e+00, 4.0343e+00, 6.0330e-0...03, 7.4318e-01,
                    2.8548e-01, 4.4032e-02, 7.8910e-04, 7.2651e-01],
                   grad_fn=<SqueezeBackward0>)
            ),)
args = ((None,), False, False, (Parameter containing:
        tensor([[ 0.1479, 0.4471, 0.2320, 0.2780, -0.1565]], requires_grad=True), Parameter containing:
        tensor([0.2146], requires_grad=True)), False)
kwargs = {'accumulate_grad': False}, attach_logging_hooks = False
def _engine_run_backward(
    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
    *args: Any,
    **kwargs: Any,
) -> tuple[torch.Tensor, ...]:
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
>       return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            t_outputs, *args, **kwargs
        )  # Calls into the C++ engine to run the backward pass
E       RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
../.venv/lib/python3.13/site-packages/torch/autograd/graph.py:824: RuntimeError
I think the reason is that when vmap splits the losses tensor to parallelize over it, it does so without creating a grad_fn, so the loss tensor that compute_one_gradient receives no longer requires grad.
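For what it’s worth, the sequential version of the same idea does work and does reuse the existing graph, it just isn’t parallel (a minimal sketch of what I mean; each losses[i] keeps a grad_fn when the indexing happens outside vmap):

# Works, but loops over the batch instead of vectorizing.
# retain_graph=True keeps the autograd graph alive across iterations.
per_sample_grads = [
    torch.autograd.grad(losses[i], params, retain_graph=True)
    for i in range(batch_size)
]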
Is there any proper way, or even any trick, to do what I want to do in parallel and without any extra forward pass? (I think torch.func.grad, torch.func.vjp, torch.func.jacrev, torch.func.jacfwd, etc. all require an extra forward pass.) It almost seems like what I need is the is_grads_batched parameter of torch.autograd.grad, but that is for batched grad_outputs, not for batched outputs.
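For reference, this is how I understand is_grads_batched is meant to be used (a minimal sketch based on my reading of the docs; the cotangents here are arbitrary): it vmaps the backward pass over a leading dimension of grad_outputs, i.e. several cotangents against the same outputs.

# Several cotangents for the *same* losses vector, stacked along a leading dimension.
cotangents = torch.randn(3, batch_size)

# One VJP per cotangent, computed by a single vmapped backward pass over the existing graph.
batched_vjps = torch.autograd.grad(
    losses, params, grad_outputs=cotangents, is_grads_batched=True, retain_graph=True
)
# I expect batched_vjps[0] to have shape [3, 1, 5] and batched_vjps[1] to have shape [3, 1].

So unless I’m missing a trick with the choice of grad_outputs, this doesn’t directly give me one gradient per element of losses.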
Thanks in advance for the help!