Calculate individual gradients for each example in a mini-batch?

Say I have the output of y = model(x), where y.size() is [batch_dimensions, 1]. I would then like to calculate the individual gradient of each example in the output with respect to model.parameters(). Naively, I could do:

grads = [torch.autograd.grad(y[batch], model.parameters(), retain_graph=True) for batch in range(batch_dimensions)]

but this is serial and slow.

I know I can, however, do:

grads = torch.autograd.grad(y, model.parameters(), grad_outputs=torch.ones_like(y))

which is an awful lot faster. However, this sums the gradients over all examples in the batch, which is not what I want. How can I compute the per-example gradients efficiently?
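A small check of the two approaches, using a toy nn.Sequential model chosen only for illustration: the serial loop yields one gradient tuple per example, and adding those tuples up reproduces the single call with grad_outputs of ones. That is why the fast version loses the per-example information.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model (illustrative): 3 input features -> 1 output per example
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
batch_dimensions = 5
x = torch.randn(batch_dimensions, 3)
y = model(x)  # shape [batch_dimensions, 1]
params = list(model.parameters())

# Serial approach: one backward pass per example
per_example = [
    torch.autograd.grad(y[b, 0], params, retain_graph=True)
    for b in range(batch_dimensions)
]

# Single backward pass: grad_outputs of ones gives the SUM over the batch
summed = torch.autograd.grad(y, params, grad_outputs=torch.ones_like(y))

# The batched result equals the sum of the per-example gradients
for p_idx, g in enumerate(summed):
    manual_sum = sum(per_example[b][p_idx] for b in range(batch_dimensions))
    assert torch.allclose(g, manual_sum, atol=1e-6)
```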


Try this:

@Yaroslav_Bulatov Hi, can this module support LSTM? Thank you!

@roylight, could I ask about your use case for individual per-example gradients?

To answer your question, cybertronai/autograd-hacks doesn’t support PyTorch’s nn.LSTM out of the box, but if you write a custom LSTM (using nn.Linear, some activation functions, and a for-loop), cybertronai/autograd-hacks should be able to help (it includes an implementation of per-sample gradients for nn.Linear).
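For nn.Linear the per-sample gradient has a closed form: if a_i is the layer input for example i and b_i is the gradient of the loss with respect to that example's layer output, the weight gradient for example i is the outer product b_i a_i^T. Below is a minimal sketch of that idea using standard forward and tensor hooks; it is not autograd-hacks' own API. It assumes the loss is a sum over examples (with a mean reduction, every per-example gradient would be scaled by 1/batch).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(3, 2)
saved = {}

def capture(module, inputs, output):
    # Forward hook: store the layer input a, shape [batch, in_features]
    saved["a"] = inputs[0].detach()
    # Tensor hook on the output: store b = dLoss/dOutput, shape [batch, out_features]
    output.register_hook(lambda grad: saved.__setitem__("b", grad.detach()))

layer.register_forward_hook(capture)

x = torch.randn(4, 3)
# Sum (not mean) over the batch so each example's gradient is unscaled
loss = layer(x).pow(2).sum()
loss.backward()

# Per-example weight gradient is the outer product b_i a_i^T
grad_w_per_example = torch.einsum("bo,bi->boi", saved["b"], saved["a"])  # [batch, out, in]
grad_b_per_example = saved["b"]                                          # [batch, out]

# Summing over the batch recovers what .backward() accumulated
assert torch.allclose(grad_w_per_example.sum(0), layer.weight.grad, atol=1e-6)
assert torch.allclose(grad_b_per_example.sum(0), layer.bias.grad, atol=1e-6)
```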

Indeed, it doesn’t support anything other than Conv/Linear, so it would need some work. In an ideal world, an autovectorization operation would be able to do this automatically, but right now it seems manual algebra is required.

@richard one common demand for per-example gradient norms is from privacy people – a datapoint with a large gradient could influence the final solution strongly, so you could “leak” information about that datapoint through the final model.
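To make the privacy use case concrete, here is a sketch of a DP-SGD-style update step, assuming the per-example gradients have already been flattened into one row per example. The function name clip_and_noise and its parameters are hypothetical, and the clipping bound and noise multiplier are illustrative values. Clipping each row bounds how much any single datapoint can move the model, which is exactly what requires per-example norms.

```python
import torch

def clip_and_noise(per_example_grads: torch.Tensor,
                   clip_norm: float,
                   noise_multiplier: float) -> torch.Tensor:
    """Hypothetical helper. per_example_grads: [batch, num_params]."""
    norms = per_example_grads.norm(dim=1, keepdim=True)  # [batch, 1]
    # Scale each row so its L2 norm is at most clip_norm
    scale = (clip_norm / (norms + 1e-12)).clamp(max=1.0)
    clipped = per_example_grads * scale
    summed = clipped.sum(dim=0)
    # Gaussian noise whose scale is tied to the clipping bound
    noise = torch.randn_like(summed) * noise_multiplier * clip_norm
    return (summed + noise) / per_example_grads.shape[0]

torch.manual_seed(0)
g = torch.tensor([[3.0, 4.0], [0.3, 0.4]])  # row norms 5.0 and 0.5
# With no noise: the first row is clipped to norm 1, the second is left unchanged
print(clip_and_noise(g, clip_norm=1.0, noise_multiplier=0.0))
```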

@Yaroslav_Bulatov we’re working on prototyping an auto-vectorization operation like JAX’s vmap in the hope that it can compute per-sample gradients (among other use cases). There’s a slight problem: it doesn’t quite work for the per-sample-gradient use case yet because of how PyTorch internals are organized, but we’re figuring that out.
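For reference, later PyTorch releases (2.x) expose this kind of operator as torch.func.vmap, which composes with torch.func.grad and torch.func.functional_call to compute per-sample gradients in one vectorized call. The sketch below assumes a PyTorch 2.x install and reuses a toy model chosen only for illustration: it differentiates the scalar output for a single example with respect to a parameter dict, then maps that function over the batch dimension of x.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
params = {k: v.detach() for k, v in model.named_parameters()}

def output_for_one_example(params, x_single):
    # Run the module with an explicit parameter dict; add a batch dim of 1
    out = functional_call(model, params, (x_single.unsqueeze(0),))
    return out.squeeze()  # scalar, so grad() can differentiate it

x = torch.randn(5, 3)
# Gradient w.r.t. params (first argument), vmapped over dim 0 of x only
per_example_grads = vmap(grad(output_for_one_example), in_dims=(None, 0))(params, x)

for name, g in per_example_grads.items():
    print(name, tuple(g.shape))  # leading dim 5 = one gradient per example
```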

I’m curious if you’ve come across other demands for per-example gradients. I’ve spoken with some folks in the differential privacy space but haven’t been able to find other use cases.

The other place is optimization research. There are thousands of papers whose implementation would benefit from efficient access to per-example gradients. A few examples:

  • Optimize faster: cluster gradients, then undersample large clusters (paper).
  • Find potentially mislabeled examples by checking whether their gradients change rapidly (paper).
  • Predict generalization by measuring the sharpness of the solution found. This needs the variance of per-example Hessians, which are gradients of the gradient function (paper).

Popular optimizers like AdaGrad and natural gradient were formulated in terms of per-example gradients. Because frameworks didn’t provide an efficient way to compute these, all practical implementations substituted batch gradients as an approximation.
Their descendants (Adam, KFAC, Shampoo, full AdaGrad) ended up inheriting this approximation. What if this approximation is bad? Without per-example gradient support, it’s too much work to check.
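A toy illustration, with synthetic numbers, of how far apart the two quantities can be. An AdaGrad-style second-moment statistic built from the batch gradient sees the square of the mean gradient, while the per-example formulation uses the mean of the squared per-example gradients. The two differ by exactly the per-example variance, so the approximation degrades as the per-example gradients disagree.

```python
import torch

# Per-example gradients for a single scalar parameter (synthetic values)
g = torch.tensor([1.0, -1.0, 1.0, -1.0])

batch_grad_sq = g.mean() ** 2              # what a batch-gradient accumulator sees
per_example_sq = (g ** 2).mean()           # mean of squared per-example gradients
variance = ((g - g.mean()) ** 2).mean()    # biased (population) variance

print(batch_grad_sq.item(), per_example_sq.item())  # 0.0 1.0
# Identity: mean of squares = square of mean + variance
assert torch.isclose(per_example_sq, batch_grad_sq + variance)
```

Here the batch gradient is zero, so a batch-based accumulator would see no signal at all, even though every individual example has a gradient of magnitude 1.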

Basically this feature would make PyTorch a more friendly research platform.

PS: efficient per-example gradients (as well as per-example Hessians, Hessian-vector products, per-example gradient norms, etc.) are all just special cases of einsum optimization. Computing these quantities comes down to a series of multiplications and additions, and an einsum optimizer is what tells you how to rearrange them efficiently.
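A small instance of that rearrangement, assuming a Linear layer applied once per example (no sequence or spatial dimension): the per-example weight gradient is b_i a_i^T, and the squared Frobenius norm of an outer product factors as ||a_i||^2 ||b_i||^2. Reordering the sums therefore gives per-example gradient norms without ever materializing the [batch, out, in] tensor. The tensors a and b below are random stand-ins for the layer inputs and output gradients.

```python
import torch

torch.manual_seed(0)
batch, d_in, d_out = 6, 5, 3
a = torch.randn(batch, d_in)   # stand-in for layer inputs (activations)
b = torch.randn(batch, d_out)  # stand-in for gradients w.r.t. layer outputs

# Route 1: materialize per-example weight gradients, then take norms
per_example = torch.einsum("bo,bi->boi", b, a)          # [batch, d_out, d_in]
norms_sq_explicit = per_example.pow(2).sum(dim=(1, 2))

# Route 2: reorder the sums; ||b_i a_i^T||_F^2 = ||a_i||^2 * ||b_i||^2
norms_sq_fast = torch.einsum("bi,bi->b", a, a) * torch.einsum("bo,bo->b", b, b)

assert torch.allclose(norms_sq_explicit, norms_sq_fast, atol=1e-5)
```

Route 1 needs memory proportional to batch * d_in * d_out, while route 2 needs only batch * (d_in + d_out).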

An example of using this perspective to improve on PyTorch’s Hessian-vector product and to compute a batch of Hessian norms efficiently:
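For context, this is the standard double-backward way to get a Hessian-vector product in PyTorch, checked against torch.autograd.functional.hvp and against an explicit Hessian. The function f is a hypothetical toy with a non-diagonal Hessian; this is only the baseline such an einsum-based approach would be compared against, not that approach itself.

```python
import torch
from torch.autograd.functional import hessian, hvp

def f(w):
    # Toy scalar function with a non-diagonal Hessian (illustrative)
    return (w[0] * w[1]) ** 2 + w.pow(3).sum()

w = torch.tensor([1.0, 2.0])
v = torch.tensor([0.5, -1.0])

# Double backward: differentiate (grad f . v) with respect to w
w_req = w.clone().requires_grad_(True)
(g,) = torch.autograd.grad(f(w_req), w_req, create_graph=True)
(hv_manual,) = torch.autograd.grad(g @ v, w_req)

_, hv_builtin = hvp(f, w, v)      # PyTorch's functional HVP
hv_dense = hessian(f, w) @ v      # explicit 2x2 Hessian, for comparison

# All three agree
assert torch.allclose(hv_manual, hv_builtin)
assert torch.allclose(hv_manual, hv_dense)
```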


Thank you for the detailed reply! I’ll take a look through the papers. This is really helpful - we’re also exploring other mechanisms to compute per-example-gradients (that aren’t through an auto-vectorization operator) and investigating use cases will help out with the decision making.


Another project that is worth mentioning: BackPACK (from "BackPACK: Packing more into backprop").