Calculate individual gradients for each example in mini-batch?

The other place is optimization research. There are thousands of papers whose implementation would benefit from efficient access to per-example gradients. A few examples:

  • Optimize faster: cluster the gradients, then undersample the large clusters (paper).
  • Find potentially mislabeled examples by checking whether their gradients change rapidly (paper).
  • Predict generalization by measuring the sharpness of a feasible solution. This needs the variance of per-example Hessians, which are gradients of the gradient function (paper).

Popular optimizers like AdaGrad and natural gradient were formulated in terms of per-example gradients. Because frameworks didn’t provide an efficient way to compute them, all practical implementations substituted batch gradients as an approximation.
Their descendants (Adam, KFAC, Shampoo, full AdaGrad) ended up inheriting this approximation. What if this approximation is bad? Without per-example gradient support, it’s too much work to check.
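
To make the gap concrete, here is a rough sketch of a diagonal AdaGrad-style accumulator computed both ways. Everything in it is an assumption for illustration: a toy linear model with squared-error loss, made-up shapes, and NumPy instead of PyTorch to keep it dependency-light. It is not how any particular optimizer is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)
B, n_in, n_out = 8, 5, 3            # hypothetical batch and layer sizes
x = rng.normal(size=(B, n_in))      # inputs
t = rng.normal(size=(B, n_out))     # targets
W = rng.normal(size=(n_out, n_in))  # weights of a toy linear model

# Per-example loss: 0.5 * ||W x_i - t_i||^2, so dL_i/dW = outer(delta_i, x_i)
delta = x @ W.T - t                                   # (B, n_out)
per_example = np.einsum('bi,bj->bij', delta, x)       # (B, n_out, n_in)

# Accumulator built from per-example gradients: sum_i g_i**2
acc_per_example = (per_example ** 2).sum(axis=0)

# Common substitute: square of the batch-mean gradient
batch_grad = per_example.mean(axis=0)
acc_batch = batch_grad ** 2

# Same per-example accumulator without materializing the (B, n_out, n_in) tensor
acc_fast = np.einsum('bi,bj->ij', delta ** 2, x ** 2)

# Relative gap between the two accumulators (after matching the scale by 1/B)
gap = np.abs(acc_per_example / B - acc_batch) / (acc_per_example / B)
```

The mean of squares is never smaller than the square of the mean, so the batch version can only underestimate each entry; how much that matters in practice is exactly the question that is hard to check without per-example gradients.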

Basically, this feature would make PyTorch a friendlier research platform.

PS: efficient per-example gradients (as well as per-example Hessians, Hessian-vector products, per-example gradient norms, etc.) are all just special cases of einsum optimization. The computation of these quantities comes down to a series of multiplications and additions, and an einsum optimizer is what tells you how to rearrange them properly.
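
A minimal sketch of that view, assuming a single linear layer with squared-error loss (shapes and names are made up for illustration). It is written with NumPy; `torch.einsum` accepts the same subscript strings. The only difference between the batch gradient and the per-example gradients is whether the batch index `b` is summed out, and the per-example norms fall out of a reordered contraction that never builds the per-example tensor.

```python
import numpy as np

rng = np.random.default_rng(0)
B, n_in, n_out = 4, 6, 3            # hypothetical sizes
x = rng.normal(size=(B, n_in))      # layer inputs
t = rng.normal(size=(B, n_out))     # targets
W = rng.normal(size=(n_out, n_in))  # layer weights

# Backprop signal at the layer output for loss_i = 0.5 * ||W x_i - t_i||^2
delta = x @ W.T - t                                        # (B, n_out)

# Batch gradient: the batch index b is contracted away
batch_grad = np.einsum('bi,bj->ij', delta, x)              # (n_out, n_in)

# Per-example gradients: same expression, but b is kept in the output
per_example_grads = np.einsum('bi,bj->bij', delta, x)      # (B, n_out, n_in)

# Per-example gradient norms, naive: materialize everything, then reduce
norms_naive = np.sqrt((per_example_grads ** 2).sum(axis=(1, 2)))

# Rearranged by hand: ||outer(d, x)||_F^2 = ||d||^2 * ||x||^2
norms_fast = np.sqrt(np.einsum('bi,bi->b', delta, delta) *
                     np.einsum('bj,bj->b', x, x))

# Or write the squared norm as one einsum and let the path optimizer
# choose the contraction order
norms_sq = np.einsum('bi,bj,bi,bj->b', delta, x, delta, x, optimize=True)
```

For real networks the same idea applies layer by layer, with the activations and backpropagated signals captured by whatever hook mechanism the framework offers.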

An example of using this perspective to improve on PyTorch’s Hessian-vector product and to compute a batch of Hessian norms efficiently.
