Is_grads_batched

Are there examples of using the parameter is_grads_batched in torch.autograd.grad? Thanks!

If I understand the usage correctly, you can avoid writing a for loop, as internally vmap will be used:

import torch

x = torch.randn(2, 2, requires_grad=True)

# Scalar output
out = x.sum()  # Size([])

# Batched approach: the leading dimension of batched_grad is the batch
# dimension, and the backward passes are vectorized internally via vmap.
batched_grad = torch.arange(3)  # Size([3])
grad, = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)

# Loop approach: one grad call per grad_output, stacked manually
# (retain_graph=True so the graph can be reused across iterations).
grads = torch.stack([torch.autograd.grad(out, x, torch.tensor(a), retain_graph=True)[0]
                     for a in range(3)])

(the example is based on PyTorch's internal test for this feature)
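
As a quick sanity check (not part of the original snippet), both approaches should return the 3 gradients stacked along a new leading dimension and agree in value:

print(grad.shape)   # torch.Size([3, 2, 2])
print(grads.shape)  # torch.Size([3, 2, 2])
print(torch.allclose(grad.float(), grads.float()))  # should print True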


Thanks a lot for the quick reply! That is exactly the answer I was looking for!

May I follow up and ask whether this parameter can also give me a batch of Jacobians? For example, see the commented code below:

x = torch.randn(2, 2)  # this is my input: batch_size 2, feature size 2
out = model(x)  # I want the Jacobian wrt the logits, so out has size 2 (batch_size) x 10 (logits)
vs = torch.eye(10)  # these are the vectors in the VJP; I need to vmap 10 times to obtain the full Jacobian
vs = torch.stack([vs] * 2, 0).permute(1, 0, 2)  # shape [10, 2, 10], because my batch size is 2
grads = torch.autograd.grad(out, tuple(model.parameters()), (vs,), is_grads_batched=True)

If my model has p parameters, I would expect this code to give me a Jacobian of size 2 (batch_size) x 10 (logits) x p (parameters). However, the batch_size dimension is always missing from the Jacobian output, so I am not sure whether I am doing something wrong.
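
For reference, here is a self-contained sketch of what I am doing; the small nn.Linear is only a stand-in for my actual model, and the loop at the end just prints the shapes that come back:

import torch
import torch.nn as nn

model = nn.Linear(2, 10)  # stand-in for my actual model: 2 features in, 10 logits out
x = torch.randn(2, 2)     # batch_size 2, feature size 2
out = model(x)            # shape [2, 10]

vs = torch.eye(10)                              # one basis vector per logit
vs = torch.stack([vs] * 2, 0).permute(1, 0, 2)  # shape [10, 2, 10]; the vmapped dim comes first

grads = torch.autograd.grad(out, tuple(model.parameters()), (vs,), is_grads_batched=True)
for p, g in zip(model.parameters(), grads):
    print(p.shape, g.shape)  # inspect which dimensions appear in each per-parameter grad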