I am facing a very weird OOM error with my current training. The OOM error happens systematically during the forward pass of the last (or second-to-last, not sure) batch of an epoch. It happens independently of the training set size, and it happens before validation.
I am logging the GPU memory consumption via nvidia-smi during training. During the training epoch the memory consumption stays constant, so I doubt it’s a typical memory leak (caused e.g. by a missing .detach() call). However, when running the last batch, the memory consumption suddenly starts to increase in the forward pass. This can be seen by printing the GPU memory consumption during different steps in the forward pass.
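A minimal sketch of how such per-step logging could be done from inside the process, as an alternative to polling nvidia-smi. The helper names `gpu_memory_mib` and `log_gpu_memory` are hypothetical and not part of monodepth2. Note that nvidia-smi also counts memory held by PyTorch's caching allocator and the CUDA context, so its numbers will generally be higher than what `torch.cuda.memory_allocated` reports.

```python
try:
    import torch  # optional import so the helper degrades gracefully without PyTorch
except ImportError:
    torch = None


def gpu_memory_mib(device=0):
    """Return (allocated, peak_allocated) in MiB; zeros if CUDA is unavailable."""
    if torch is None or not torch.cuda.is_available():
        return 0.0, 0.0
    mib = 1024 ** 2
    return (torch.cuda.memory_allocated(device) / mib,
            torch.cuda.max_memory_allocated(device) / mib)


def log_gpu_memory(tag, device=0):
    """Print and return one log line for the given step tag (hypothetical helper)."""
    allocated, peak = gpu_memory_mib(device)
    line = "[{}] allocated={:.1f} MiB, peak={:.1f} MiB".format(tag, allocated, peak)
    print(line)
    return line


# Example: call at several points inside the forward pass.
log_gpu_memory("before encoder")
log_gpu_memory("after decoder")
```

Calling this with the same tags for every batch makes it easier to see at which step the last batch starts to deviate from the earlier ones.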
Is there something special happening at epoch end, when the queue size of the dataloader approaches zero? Did anybody else face this phenomenon?
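One thing that can differ at epoch end: if the dataset size is not a multiple of the batch size, a `DataLoader` with the default `drop_last=False` yields a shorter final batch. Whether this has anything to do with the OOM here is not established; the sketch below only illustrates the batch-size arithmetic in plain Python, with a hypothetical helper `batch_sizes` and made-up numbers.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """List the batch sizes of one epoch, mirroring DataLoader's drop_last semantics."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # shorter final batch
    return sizes


# Illustrative numbers only: 50 samples, batch size 12.
print(batch_sizes(50, 12))                  # [12, 12, 12, 12, 2]
print(batch_sizes(50, 12, drop_last=True))  # [12, 12, 12, 12]
```

Setting `drop_last=True` on the training `DataLoader` would be a cheap way to test whether the odd-sized final batch plays a role.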
P.S.: While I said the error happens systematically, the training did work at one time for a couple of epochs, so there does seem to be some randomness to it. However, it has raised the OOM error at epoch end approximately a dozen times now.
P.P.S.: The training pipeline is a modified monodepth2 and I’m getting the error at this point: https://github.com/nianticlabs/monodepth2/blob/master/trainer.py#L285 I am using only 1 GPU to train.
Data is pushed to the GPU not inside the Dataset, but just before the forward pass.
By now I have also run the memory trace suggested in the thread "How to debug causes of GPU memory leaks?". However, in this case the training somehow runs smoothly (and extremely slowly) for several epochs… Could it be that at some point GPU memory is not deallocated quickly enough, and that some wait statements would solve it?
Also forgot to mention: I am using an older version of monodepth, namely 0.4.1
What exactly do you mean by plain PyTorch code? What I can say is that the monodepth2 training runs with different parameter settings, e.g. different input sizes, batch sizes or network architectures. The current configuration is already pretty close to the GPU memory limit. But I will test today whether this memory usage spike in the last batch also happens for other configurations.