Partial Gradient Tracking?

Hey all,

We have a model with a large parameter count, but each training step only updates a tiny subset of those parameters. The forward pass behaves the same way: it only reads a small subset of the total parameter space, even though the parameters are allocated in one contiguous tensor (imagine looking up the required parameters in a table based on each input).
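A small sketch of this setup. It assumes the lookup is plain advanced indexing into a 1-D parameter table, uses `torch.randint` as a stand-in for the real hash, and the sizes are hypothetical. It shows why peak memory does not shrink: the forward pass reads only a few dozen elements, but `table.grad` is still allocated at the full size of the table.

```python
import torch

# Hypothetical sizes: a large flat table and a few hashed slots per item.
NUM_SLOTS = 1_000_000
BATCH, SLOTS_PER_ITEM = 4, 8

table = torch.nn.Parameter(torch.randn(NUM_SLOTS))

# Stand-in for the real hash: random slot indices for each batch item.
torch.manual_seed(0)
idx = torch.randint(0, NUM_SLOTS, (BATCH, SLOTS_PER_ITEM))

# The forward pass reads only BATCH * SLOTS_PER_ITEM elements...
out = table[idx].sum(dim=1)
out.sum().backward()

# ...but the gradient is a dense tensor the size of the whole table.
print(table.grad.shape)  # torch.Size([1000000])
print(int((table.grad != 0).sum()), "nonzero entries")
```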

Is it possible to track gradients for only part of a tensor on each step to reduce peak VRAM use? Or does anyone have advice on how to structure this so that memory scales with the number of parameters actually updated?

Note that the parameter tensors are not “sparse”: every element of the tensor has a value, it’s just that many values don’t change on a given step. So I don’t /think/ SparseAdam applies, but maybe I’m wrong.
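For what it’s worth, `torch.optim.SparseAdam` is built around sparse gradients, not sparse parameters: with `nn.Embedding(..., sparse=True)` the weight stays an ordinary dense tensor while its `.grad` is a sparse COO tensor covering only the looked-up rows. A minimal sketch, assuming the hashed lookup can be expressed as an embedding lookup with one scalar per slot (`DIM = 1`); the sizes and random indices are placeholders.

```python
import torch
import torch.nn as nn

# Hypothetical sizes; DIM = 1 means one scalar per hashed slot.
NUM_SLOTS, DIM = 100_000, 1

# The weight is an ordinary dense Parameter; sparse=True only changes the
# layout of the gradient produced by the lookup.
emb = nn.Embedding(NUM_SLOTS, DIM, sparse=True)
opt = torch.optim.SparseAdam(emb.parameters(), lr=1e-3)

idx = torch.randint(0, NUM_SLOTS, (4, 8))  # stand-in for hashed lookups
before = emb.weight.detach().clone()

opt.zero_grad()
loss = emb(idx).pow(2).sum()
loss.backward()

print(emb.weight.is_sparse)       # False: the parameter itself stays dense
print(emb.weight.grad.is_sparse)  # True: only the looked-up rows appear
opt.step()                        # updates only those rows
```

One caveat: SparseAdam still allocates its two moment buffers at the full size of the parameter, so optimizer state remains proportional to the table. What it avoids is the dense gradient and the work of updating untouched rows.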


I don’t believe your use case is easily doable, since autograd computes the gradient for the entire parameter, not for a subtensor of it.
You could check this approach, which splits a parameter into a frozen and a trainable part and might fit your use case. I’ve also posted a way to replace built-in layers with this custom module using torch.fx, in case that’s helpful.
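For reference, here is one way a frozen/trainable split can be written; the linked approach may be structured differently, and `PartiallyTrainable` is a hypothetical name. It assumes a 1-D table and a fixed set of trainable indices: the trainable values live in a small Parameter and everything else in a buffer, so the gradient and optimizer state are sized by the trainable set rather than the whole table.

```python
import torch
import torch.nn as nn


class PartiallyTrainable(nn.Module):
    """Hypothetical module: a frozen 1-D table plus a trainable subset of it."""

    def __init__(self, full: torch.Tensor, trainable_idx: torch.Tensor):
        super().__init__()
        # Frozen values: a buffer, so no gradient or optimizer state.
        self.register_buffer("frozen", full.detach().clone())
        # pos[i] = position of slot i inside `trainable`, or -1 if frozen.
        pos = torch.full((full.numel(),), -1, dtype=torch.long)
        pos[trainable_idx] = torch.arange(trainable_idx.numel())
        self.register_buffer("pos", pos)
        # Only these values are Parameters.
        self.trainable = nn.Parameter(full.detach()[trainable_idx].clone())

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        p = self.pos[idx]
        is_trainable = p >= 0
        # Gather from both sources and pick per element; gradient reaches
        # `trainable` only at positions where is_trainable is True.
        from_trainable = self.trainable[p.clamp(min=0)]
        return torch.where(is_trainable, from_trainable, self.frozen[idx])


full = torch.arange(10, dtype=torch.float32)
m = PartiallyTrainable(full, torch.tensor([2, 5]))
out = m(torch.tensor([[2, 3], [5, 5]]))
out.sum().backward()
print(m.trainable.grad)  # tensor([1., 2.])
```

The catch is that `trainable_idx` has to be chosen when the module is built, so the split only helps when the trainable subset is fixed.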

Thanks! Yeah, I don’t think frozen subsets will work, because the subset is determined by a hash and is a different set of scattered elements for every item in the batch.
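Since the touched elements differ for every item, one workaround (not from this thread; a minimal sketch assuming plain SGD is acceptable) is to keep the full table outside autograd and, on each step, gather only the distinct slots the batch touches into a small leaf tensor. `torch.unique(..., return_inverse=True)` returns both the distinct slots and a map from every lookup back into that compact tensor, so the gradient is sized by the number of distinct slots rather than by the table. `train_step`, the sizes, and the MSE objective are hypothetical, and `torch.randint` stands in for the hash.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes and learning rate.
NUM_SLOTS = 1_000_000
BATCH, SLOTS_PER_ITEM = 4, 8
LR = 0.1

# The full table is a plain tensor with requires_grad=False, so autograd
# never allocates a gradient for it.
table = torch.randn(NUM_SLOTS)


def train_step(idx: torch.Tensor, target: torch.Tensor) -> float:
    # Distinct slots touched by this batch, plus a map from every lookup in
    # `idx` to its position in `uniq` (same shape as `idx`).
    uniq, inverse = torch.unique(idx, return_inverse=True)

    # Small leaf tensor holding only the touched values (indexing copies).
    local = table[uniq].requires_grad_(True)

    # Forward/backward on the compact tensor: local.grad has len(uniq) entries.
    pred = local[inverse].sum(dim=1)
    loss = F.mse_loss(pred, target)
    loss.backward()

    # Plain SGD update written back into the big table at the touched slots.
    table[uniq] -= LR * local.grad
    return loss.item()


idx = torch.randint(0, NUM_SLOTS, (BATCH, SLOTS_PER_ITEM))  # stand-in for hashing
target = torch.zeros(BATCH)
for step in range(3):
    print(step, train_step(idx, target))
```

A stateful optimizer such as Adam would need its moment buffers indexed the same way (for example, kept as full-size tensors and read and written only at `uniq`), which this sketch does not handle.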