The common way to compute several different gradients is to call the autograd.grad function multiple times. But grad does accept multiple outputs and inputs, which makes me think it can compute the gradients in a single batched call.
Here is some example code:
import torch
from torch.autograd import grad
torch.set_grad_enabled(True)
x = torch.rand(1, 3) # shape (..., 3)
x.requires_grad = True
y = x @ torch.rand(3, 3) # shape (..., 3)
assert y.grad_fn
def attempt_1(x, y):
    y0_dx, = grad([y[..., 0]], [x], [torch.ones_like(y[..., 0])], create_graph=True)
    y1_dx, = grad([y[..., 1]], [x], [torch.ones_like(y[..., 1])], create_graph=True)
    y2_dx, = grad([y[..., 2]], [x], [torch.ones_like(y[..., 2])], create_graph=True)
    return y0_dx, y1_dx, y2_dx

def attempt_2(x, y):
    y0_dx, y1_dx, y2_dx = grad(
        [y[..., 0], y[..., 1], y[..., 2]],
        [x, x, x],
        [torch.ones_like(y[..., 0])] * 3,
        create_graph=True,
    )
    return y0_dx, y1_dx, y2_dx
print(attempt_1(x, y))
print(attempt_2(x, y))
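For context, a minimal experiment of my own (separate from the code above; the tensor names are mine) suggests that when the same tensor is passed repeatedly as an input, grad returns the summed gradient of all outputs in every slot rather than per-output gradients:

```python
import torch
from torch.autograd import grad

x = torch.rand(1, 3)
x.requires_grad = True
a = (x ** 2).sum()  # da/dx = 2x
b = (x ** 3).sum()  # db/dx = 3x^2

# The same tensor x is passed twice as an input.
g1, g2 = grad([a, b], [x, x])
# Both g1 and g2 come out as d(a + b)/dx = 2x + 3x^2,
# not as the pair (da/dx, db/dx).
```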
attempt_1 works as intended, but is slow. attempt_2 should, to my understanding, be equivalent, but it is incorrect. Even worse: all of its outputs are the same! How can I batch the multiple grad calls seen in attempt_1 into something like attempt_2?
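The closest thing I have found so far is the is_grads_batched flag on grad (added in PyTorch 1.11), but I am not sure it is the intended approach. A sketch with my own variable names: stack one one-hot grad_outputs row per output component along a new leading batch dimension, and let autograd vmap the backward pass over it.

```python
import torch
from torch.autograd import grad

x = torch.rand(1, 3)
x.requires_grad = True
W = torch.rand(3, 3)
y = x @ W  # shape (1, 3)

# One one-hot grad_outputs row per output component, stacked along
# a leading batch dimension of size 3: shape (3, *y.shape).
basis = torch.eye(3).reshape(3, 1, 3)

# is_grads_batched=True vmaps the backward pass over that leading
# dimension, so this is one call instead of three.
jac, = grad([y], [x], [basis], create_graph=True, is_grads_batched=True)
# jac[i] is d y[..., i] / d x, so jac has shape (3, 1, 3)
```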