Issues training models with nn.Embedding with sparse gradients

I’m working with a model that uses a rather large instance of nn.Embedding (the weight matrix is several million rows by several hundred columns). This, along with the other model parameters, fits into my GPU memory, but unsurprisingly I run out of memory during backpropagation unless I construct the embedding (with sparse=True) so that the gradient with respect to the weights is a sparse tensor. That’s not a problem in and of itself, as I can use optim.SGD to do the training. However, in doing so I have noticed three issues I’d like to ask about:
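For concreteness, here is a minimal sketch of the setup I mean (toy sizes and illustrative code only, not my actual model): an nn.Embedding built with sparse=True so that backprop produces a sparse gradient for its weight, trained with plain optim.SGD.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Toy sizes for illustration; the real embedding has several million rows.
emb = nn.Embedding(num_embeddings=100_000, embedding_dim=300, sparse=True)
opt = optim.SGD(emb.parameters(), lr=0.1)  # no momentum, no weight_decay

idx = torch.randint(0, 100_000, (32,))
loss = emb(idx).pow(2).sum()
loss.backward()      # emb.weight.grad is now a sparse COO tensor
opt.step()
opt.zero_grad()
```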

(1) Momentum (standard or Nesterov) in optim.SGD doesn’t seem to work with sparse gradients. I can work around this by either not using momentum, or applying patches I’ve written for the optimizer to enable it. These changes are relatively straightforward; it’s just a matter of allocating the momentum buffer like the gradient (i.e. sparse when the gradient is sparse), and adding a call or two to coalesce the gradient.
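To illustrate the idea (a rough sketch only, not my actual patch; the function and state names here are made up), the momentum update for a single parameter with a sparse gradient looks roughly like this:

```python
import torch

def sgd_step_with_sparse_momentum(param, state, lr=0.1, momentum=0.9):
    """One SGD step for a parameter whose .grad is a sparse COO tensor."""
    grad = param.grad.coalesce()              # merge duplicate indices first
    buf = state.get("momentum_buffer")
    if buf is None:
        # Allocate the buffer like the gradient, i.e. sparse.
        buf = grad.clone().detach()
    else:
        buf = (buf * momentum + grad).coalesce()
    state["momentum_buffer"] = buf
    param.data.add_(buf, alpha=-lr)           # adding a sparse tensor to a dense one
```

One caveat is that a sparse buffer’s index set grows over time as more rows are touched, though it still avoids allocating a dense momentum buffer up front.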

(2) nn.utils.clip_grad_norm_ also seems to be broken when using sparse gradients. Again, I have been able to work around this by using an edited version of that function which detects sparse gradients and handles them appropriately; the changes are fairly simple.
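Roughly speaking (again a sketch with made-up names, not my exact edit), the sparse-aware version computes each gradient’s norm from its coalesced values and then scales sparse and dense gradients alike:

```python
import torch

def clip_grad_norm_sparse_(parameters, max_norm):
    """Like nn.utils.clip_grad_norm_, but tolerating sparse gradients."""
    grads = [p.grad for p in parameters if p.grad is not None]
    norms = []
    for g in grads:
        if g.is_sparse:
            norms.append(g.coalesce().values().norm(2))  # norm over stored values
        else:
            norms.append(g.norm(2))
    total_norm = torch.norm(torch.stack(norms), 2)
    clip_coef = float(max_norm / (total_norm + 1e-6))
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)   # assumes in-place scalar mul_ on sparse grads
    return total_norm
```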

(3) Weight decay is not implemented for sparse gradients. This is understandable, as the regularization term added to the loss for weight decay cannot have a sparse gradient, by definition. I’ve written some code to apply weight decay only to those entries in a parameter for which the gradient of the unregularized loss is non-zero, which gives some of the effect of true weight decay. This code is not quite so simple.
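To give the flavor of it (a much simplified sketch of the general idea, with made-up names, rather than my actual code), the trick is to shrink only the embedding rows that appear in the sparse gradient:

```python
import torch

def lazy_weight_decay_(param, lr=0.1, weight_decay=1e-4):
    """Shrink only the rows that actually received a (sparse) gradient."""
    grad = param.grad.coalesce()
    rows = grad.indices()[0].unique()          # embedding rows touched this step
    # Roughly equivalent to adding weight_decay * param to the gradient,
    # restricted to these rows, then taking the usual SGD step on that part.
    param.data[rows] *= (1.0 - lr * weight_decay)
```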

Wait, you say, where are the questions? Mainly I’d like to know whether (1) and (2) are known issues. If not, can anyone point me to the appropriate venue to report them (assuming this isn’t it)? Also, if anyone is aware of better workarounds than those I described above, I’d love to hear about them. I can probably get permission from my employers to share some of the patches I’ve made; assuming that is the case, would there be any interest from the PyTorch developers?

Thanks for listening, and let me know if you have any questions.

Robert E. Beaudoin

I think this is (or was at some point) a known issue. Since I have the same problem, I encourage you to raise it on: https://github.com/pytorch/pytorch/issues