Imbalanced GPU memory usage training LSTM

I’m training a language model using the code here:

I have made some slight changes so that the model can be trained across multiple GPUs. However, the GPU memory usage is extremely imbalanced.

I understand that one GPU is designated to gather and store all outputs. Is there any way I can balance the memory usage? Or can I dedicate one GPU to gathering outputs and use the rest for training on batches?



I found the reason: the outputs are gathered back onto one GPU and the loss is calculated there. If you move the loss calculation into model.forward(), the problem is resolved.
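To illustrate, here is a minimal sketch of that pattern with a hypothetical LSTM language model (the model names, sizes, and wrapper class are my own, not from the original code). Under nn.DataParallel, forward() runs on each GPU with its own input shard, so computing the loss inside forward() means only a small per-shard loss scalar is gathered back to the default GPU instead of the full output logits:

```python
import torch
import torch.nn as nn

class LMWithLoss(nn.Module):
    """Wraps an LSTM language model so the loss is computed inside forward()."""
    def __init__(self, vocab_size=1000, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        out, _ = self.lstm(self.embed(inputs))
        logits = self.decoder(out)  # shape (batch, seq_len, vocab)
        # Computing the loss here means DataParallel gathers one scalar
        # per GPU rather than the full logits tensor.
        return self.criterion(logits.view(-1, logits.size(-1)),
                              targets.view(-1))

model = LMWithLoss()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()

inputs = torch.randint(0, 1000, (8, 20))
targets = torch.randint(0, 1000, (8, 20))
loss = model(inputs, targets)
loss = loss.mean()  # average the per-GPU losses DataParallel gathered
loss.backward()
```

Note the `loss.mean()` after the forward pass: with multiple GPUs, DataParallel returns one loss value per device, so you still need to reduce them before calling backward().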

Does this actually speed things up? I am facing a similar issue. Thanks

I think PyTorch should fix this by moving parameter storage and backprop to the CPU, just like TensorFlow does. I will switch to TF.