Getting elementwise gradients is slow

I’m trying to get the gradient for each sample like this:

grads = torch.autograd.grad(
    losses, model.parameters(),
    torch.eye(len(losses), device=losses.device),
    retain_graph=True, is_grads_batched=True, allow_unused=True,
)

where losses is a vector of losses, one for each sample.
It works. However, it is significantly slower than computing the gradient of the mean loss like this:

torch.autograd.grad(losses.mean(), model.parameters(), retain_graph=True, allow_unused=True)

My guess was that these should run in roughly the same time, since both perform the same calculations; the second snippet even has the extra step of computing the average.

Is it supposed to be slower?
Can I achieve my goal in a faster way?


Hi Liran!

Your first snippet (the one with is_grads_batched=True) computes a batch
of gradients, one for each element of losses. (It should be faster than a
python loop because autograd.grad presumably uses vmap() vectorization
under the hood, but it still performs multiple gradient computations.)

Your second snippet, however, computes just a single gradient of the scalar
quantity losses.mean(), so it really is doing less work.

To reiterate, the two versions are not doing the same calculations. The
first computes a batch of gradients, while the second computes just a
single gradient.
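As a concrete illustration of the difference, here is a minimal sketch with a toy linear model (the model and loss below are stand-ins for your actual code). The batched call returns one gradient per sample, while the mean-loss call returns a single gradient:

```python
import torch

model = torch.nn.Linear(3, 1)
x = torch.randn(5, 3)
losses = model(x).squeeze(-1) ** 2  # one scalar loss per sample, shape (5,)

# First version: a batch of gradients, one per element of losses
batched = torch.autograd.grad(
    losses, model.parameters(),
    torch.eye(len(losses), device=losses.device),
    retain_graph=True, is_grads_batched=True, allow_unused=True,
)

# Second version: a single gradient of the scalar mean
single = torch.autograd.grad(losses.mean(), model.parameters(), allow_unused=True)

print(batched[0].shape)  # torch.Size([5, 1, 3]) -- five weight gradients
print(single[0].shape)   # torch.Size([1, 3])   -- one weight gradient
```

Note that averaging the batched gradients over the sample dimension recovers the single mean-loss gradient, which is exactly why the second version can get away with one backward pass.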

It’s true that the second snippet also has to compute the average. But by
averaging the losses together, you reduce them to a single scalar for which
you compute a single gradient, rather than a batch.

As for whether you can do it faster: possibly, but no guarantees. Depending
on your use case it might be possible to rearrange your computation to use
pytorch’s so-called forward-mode automatic differentiation. This could be
faster, but could also be slower.
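Another option worth benchmarking is the torch.func composition of grad() and vmap() from PyTorch’s per-sample-gradients recipe (available in PyTorch 2.x). Whether it beats is_grads_batched=True depends on your model; the toy model, data, and loss below are placeholders for your own:

```python
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(4, 1)
# Expose the parameters as an explicit input so grad() can differentiate w.r.t. them
params = {name: p.detach() for name, p in model.named_parameters()}

def sample_loss(params, x, y):
    # x, y are a single sample; add the batch dim just for the forward call
    pred = functional_call(model, params, (x.unsqueeze(0),)).squeeze()
    return (pred - y) ** 2

x = torch.randn(8, 4)
y = torch.randn(8)

# vmap over the batch dimension of x and y; params (in_dims=None) are shared
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, y)

print(per_sample_grads["weight"].shape)  # torch.Size([8, 1, 4])
print(per_sample_grads["bias"].shape)    # torch.Size([8, 1])
```

This avoids materializing the torch.eye(len(losses)) matrix of grad_outputs, since vmap() pushes the batching directly into a vectorized gradient computation.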


K. Frank
