Handling OOM errors - question / feature suggestion

Thanks to the nice people here, I’m aware that “model checkpointing” is now supported in “master”, and I will definitely explore it. But I want to take a step back and ask: why isn’t the following, at least conceptually simpler, approach implemented in PyTorch and other APIs (or at least I don’t think it is implemented…)?

Imagine that you want your app to run fully on the GPU. What if, every time GPU memory ran out, we evicted a resource from GPU memory to system memory, and moved it back when needed?
Now, I understand that the allocation + memory copy can be very costly; however:

  1. Sometimes it’s better to pay this penalty than not to have the workload supported at all.
  2. It should still be faster than pure CPU mode, since I believe that in quite a few cases the computation time will not be “swallowed” by the allocation + copy times.
  3. There are existing compute nodes with VERY fast memory-transfer speeds between GPU memory and system memory.
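To make the idea concrete, here is a minimal pure-Python sketch of the eviction scheme described above. No actual CUDA calls are made; the `device` pool and `host` store are stand-ins I invented for GPU memory and system RAM, and the capacity is counted in abstract bytes. A real implementation would move tensors with the usual device-to-host copies instead.

```python
from collections import OrderedDict

class SwappingPool:
    """Toy model of GPU memory with LRU eviction to host RAM.

    Entries that no longer fit on the 'device' are spilled to a
    'host' dict, and transparently reloaded on the next access.
    """
    def __init__(self, device_capacity):
        self.capacity = device_capacity
        self.device = OrderedDict()   # name -> size, in LRU order
        self.host = {}                # name -> size (evicted entries)
        self.used = 0

    def allocate(self, name, size):
        if size > self.capacity:
            raise MemoryError(f"{name} ({size}) exceeds device capacity")
        # Evict least-recently-used entries until the new one fits.
        while self.used + size > self.capacity:
            victim, vsize = self.device.popitem(last=False)
            self.host[victim] = vsize
            self.used -= vsize
        self.device[name] = size
        self.used += size

    def access(self, name):
        if name in self.device:
            self.device.move_to_end(name)  # mark as recently used
            return "device"
        if name in self.host:
            size = self.host.pop(name)     # reload; may evict others
            self.allocate(name, size)
            return "reloaded"
        raise KeyError(name)
```

For example, with `device_capacity=10`, allocating `"a"` (size 6) and then `"b"` (size 6) spills `"a"` to host; a later `access("a")` reloads it, spilling `"b"` in turn.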

Note - this is not any form of “complaint”; I’m very grateful for PyTorch’s existing code. I just want to know whether I’m missing something in the picture here, and I’m wondering whether such a contribution would be of interest to anyone.

fyi https://github.com/pytorch/pytorch/pull/5313


Notice that:

  • I’m not talking about using “managed memory” in the sense mentioned there - everything can still use only cudaMalloc. The eviction and reloading of variables can be done at a high level in the forward/backward passes.

  • It will not kick in until GPU memory is exhausted, so it will be identical in performance to the existing system until GPU memory runs out.
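The “only kicks in on OOM” behavior from the second bullet could be sketched as a retry wrapper. This is a toy illustration with names I made up (`run_with_eviction`, `evict_one`); in real PyTorch the error to catch would be the CUDA out-of-memory `RuntimeError`, not Python’s `MemoryError`, and `evict_one` would move a cached tensor to pinned host memory.

```python
def run_with_eviction(op, evict_one, max_retries=8):
    """Run `op`; on an out-of-memory error, evict one resource to
    host memory and retry.  Until memory actually runs out, `op`
    executes exactly as it would without this wrapper, so there is
    no overhead on the fast path."""
    for _ in range(max_retries):
        try:
            return op()
        except MemoryError:
            if not evict_one():   # nothing left to evict: give up
                raise
    raise MemoryError("still out of memory after evictions")
```

A caller would wrap each forward/backward step in `run_with_eviction`, supplying an `evict_one` that picks a victim tensor (e.g. least recently used) and copies it to system memory.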

Anyway, I guess the burden of proof is on me, then. Thanks for sharing the relevant thread :slight_smile: