Hey all,

We have a model with a large total parameter count, but each training step only updates a tiny subset of those parameters. The forward pass behaves the same way: it uses only a small subset of the total parameter space, even though the parameters are allocated in one contiguous tensor (imagine looking up the required rows of a table based on each input).

Is it possible to track gradients for only part of a tensor on each step, to reduce peak VRAM use? Or does anyone have advice on how to structure this so that memory scales with the number of parameters actually updated, rather than with the total?

Note that the parameter tensors are not “sparse”: every element of the tensor has a value; it’s just that many values don’t change on a given step. So I don’t /think/ SparseAdam applies, but maybe I’m wrong.
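To make the question concrete, here’s a minimal sketch of the kind of structure I mean (all names and the toy loss are made up). The idea would be to gather just the touched rows into a small leaf tensor each step, so autograd only tracks those, then scatter the update back:

```python
import torch

# Hypothetical setup: a big "table" of parameters, only a few rows used per step.
num_rows, dim = 10_000, 64
table = torch.randn(num_rows, dim)  # plain tensor, not a Parameter, no grad tracking

def train_step(idx, x, lr=0.1):
    # Gather only the rows this step touches into a small leaf tensor.
    rows = table[idx].clone().requires_grad_(True)
    # Toy forward/loss: grads are tracked only for `rows`,
    # so autograd memory scales with len(idx), not num_rows.
    loss = ((x @ rows.t()) ** 2).mean()
    loss.backward()
    with torch.no_grad():
        # Scatter the (plain SGD) update back into the big table.
        table[idx] -= lr * rows.grad
    return rows.grad.shape

idx = torch.tensor([3, 17, 42])
grad_shape = train_step(idx, torch.randn(5, dim))
# rows.grad has shape (3, 64): proportional to rows touched, not table size.
```

The obvious catch is stateful optimizers: with Adam-style moments you’d presumably have to gather/scatter the per-row optimizer state the same way, which is part of what I’m unsure about.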

–

Jeremy