I have a training pipeline that offloads various components (model, model EMA, optimizer) to CPU at different stages of each training step, and does so asynchronously (e.g. for each data buffer, calling buffer.to('cpu', non_blocking=True)).
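For concreteness, the offload step looks roughly like the sketch below. This is a simplified illustration, not my actual code: the helper name `offload_async` and the toy `Linear` model are placeholders, and it falls back to CPU when no GPU is present so it runs anywhere.

```python
import torch

def offload_async(tensors):
    """Start a non-blocking copy to CPU for each tensor and return the host copies.

    Hypothetical helper: on a CUDA tensor the copy is only queued here, so the
    returned host tensors must not be read until the stream has been synchronized.
    """
    return [t.to("cpu", non_blocking=True) for t in tensors]

# Toy stand-in for the real model / EMA / optimizer state (assumption).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 4).to(device)

host_state = offload_async(list(model.state_dict().values()))

if torch.cuda.is_available():
    # Wait for all queued device-to-host copies before touching host_state.
    torch.cuda.synchronize()
```

Every call creates new host tensors, so repeating this for the model, the EMA weights and the optimizer state at each step produces many fresh host allocations.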
Making these transfers non-blocking results in significant speed increases (almost 2x). However, it also blows up the CPU RAM usage and results in my training process getting OOMKilled. My guess is that due to the large number of in-flight transfers to/from CPU, the process reserves more RAM than it really needs.
Is there any way to achieve a middle ground here? I could just make all the transfers blocking, and training runs fine that way, but it would be cool to somehow give PyTorch a hint about maximum RAM usage in this async setting.
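One candidate middle ground, sketched here under assumptions rather than as a known fix: as far as I understand, a non-blocking copy to CPU allocates a new page-locked host tensor on every call, so preallocating one host buffer per offloaded tensor and reusing it with copy_ might keep the host footprint fixed while the transfers stay asynchronous. The class `ReusableHostBuffers` and its `offload` and `wait` methods are hypothetical names, not a PyTorch API.

```python
import torch

class ReusableHostBuffers:
    """Hypothetical helper: one preallocated host buffer per offloaded tensor."""

    def __init__(self):
        self._host = {}    # name -> host tensor, allocated once and then reused
        self._events = {}  # name -> CUDA event recorded after the latest copy

    def offload(self, name, tensor):
        """Queue a copy of `tensor` into its reusable host buffer."""
        host = self._host.get(name)
        if host is None or host.shape != tensor.shape or host.dtype != tensor.dtype:
            # Pin the buffer only when the source lives on a GPU, so the
            # sketch also runs on CPU-only machines.
            host = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu",
                               pin_memory=tensor.is_cuda)
            self._host[name] = host
        host.copy_(tensor, non_blocking=True)
        if tensor.is_cuda:
            # Mark the point on the current stream where this copy finishes.
            event = torch.cuda.Event()
            event.record()
            self._events[name] = event
        return host

    def wait(self, name):
        """Block until the latest copy for `name` is done, then return the host buffer."""
        event = self._events.pop(name, None)
        if event is not None:
            event.synchronize()
        return self._host[name]
```

With this pattern the host memory in use should be bounded by the total size of the registered tensors, since repeated offloads of the same name write into the same buffer. I have not verified that it actually avoids the OOM kill in my setup.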