This line of code computes a batch of gradients, one for each element of losses. (It should be faster than a Python loop because autograd.grad
presumably uses vmap() vectorization under the hood, but it still performs
multiple gradient computations.)
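As a minimal sketch of the batched version, here is a toy example using autograd.grad's is_grads_batched argument, where grad_outputs supplies one one-hot row per loss element (the specific tensors are illustrative, not from the original post):

```python
import torch

# Toy example: three independent "losses," one per element of x.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
losses = x ** 2  # losses[i] = x[i] ** 2

# One gradient per element of losses: each row of grad_outputs is a
# one-hot vector, and is_grads_batched vectorizes over that batch.
(grads,) = torch.autograd.grad(
    losses,
    x,
    grad_outputs=torch.eye(3),
    is_grads_batched=True,
)
# grads[i] is the gradient of losses[i] w.r.t. x, so here grads is the
# (diagonal) Jacobian with grads[i, i] == 2 * x[i].
```

Even though the batch of backward passes is vectorized, it is still a batch of gradient computations, not one.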
This line of code, however, computes just a single gradient of the scalar
quantity losses.mean(), so it really is doing less work.
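For contrast, a sketch of the scalar version (same illustrative toy tensors): a single backward pass through losses.mean() produces one gradient, which is the average of the per-element gradients.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
losses = x ** 2  # losses[i] = x[i] ** 2

# One scalar, one backward pass, one gradient.
(grad,) = torch.autograd.grad(losses.mean(), x)
# grad[i] == 2 * x[i] / 3, i.e. the mean of the per-element gradients.
```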
To reiterate, the two versions are not doing the same calculations. The
first computes a batch of gradients, while the second computes just a
single gradient.
This is true. But by averaging the losses together, you reduce them to a
single scalar for which you compute a single gradient, rather than a batch.
Possibly, but no guarantees. Depending on your use case, it might be
possible to rearrange your computation to use PyTorch's forward-mode
automatic differentiation. This could be faster, but could also be slower.
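For reference, a minimal sketch of forward-mode differentiation using torch.func.jvp (available in recent PyTorch versions; the function f and the tensors here are made-up illustrations): a Jacobian-vector product is computed in a single forward pass, with no backward graph.

```python
import torch
from torch.func import jvp

def f(x):
    # Hypothetical elementwise function standing in for your computation.
    return x ** 2

x = torch.tensor([1.0, 2.0, 3.0])
v = torch.ones(3)  # tangent (direction) vector

# One forward pass yields both f(x) and the directional derivative
# J_f(x) @ v.
out, directional = jvp(f, (x,), (v,))
# directional[i] == 2 * x[i] * v[i]
```

Whether this beats reverse mode depends on the shapes involved; forward mode tends to win when the number of inputs is small relative to the number of outputs, which is why it could be either faster or slower for your case.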