Handling oom error during DDP backward

Hi, I’m wondering how to deal with occasional OOM errors that happen during the DDP backward pass.

For the forward pass, an OOM can be caught with a simple try-except block. For the backward pass, however, loss.backward() computes the gradients while the registered hooks perform gradient reduction at the same time.
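For reference, a minimal sketch of catching a forward-pass OOM (the helper and the `safe_forward` name are illustrative, not from this thread; `model` and `batch` stand in for your own objects):

```python
def is_oom_error(exc: BaseException) -> bool:
    """CUDA OOMs surface as a RuntimeError whose message contains
    'out of memory' (recent PyTorch's torch.cuda.OutOfMemoryError is a
    RuntimeError subclass, so this check covers both)."""
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc)

def safe_forward(model, batch):
    """Run the forward pass, returning None on OOM instead of crashing."""
    import torch  # local import keeps the helper above dependency-free
    try:
        return model(batch)
    except RuntimeError as e:
        if not is_oom_error(e):
            raise
        torch.cuda.empty_cache()  # release cached blocks before skip/retry
        return None
```

As the question notes, this pattern is only safe for the forward pass; wrapping backward the same way is what causes the desynchronization discussed below.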

Is it possible for OOM errors during backward in some processes to cause a hang, with the other, successful processes waiting for them indefinitely? If so, is there a nice way to recover from this problem?

Yes, it is. If one process hit OOM and skipped or reran the backward pass, it would cause de-synchronization across processes in the same group, which would lead to a hang or crash.
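One common mitigation pattern (an assumption on my part, not something proposed in this thread) is to all-reduce a per-rank "I hit OOM" flag before calling backward, so that every rank makes the same skip-or-proceed decision and no collective is left waiting:

```python
def should_skip_step(oom_flag_sum: float) -> bool:
    """Decision rule: if any rank reported OOM (sum > 0), ALL ranks skip."""
    return oom_flag_sum > 0

def sync_oom_flag(found_oom: bool, device="cuda") -> bool:
    """All-reduce a per-rank OOM flag so every rank agrees on skipping.
    Assumes torch.distributed is already initialized (e.g. by DDP setup)."""
    import torch
    import torch.distributed as dist
    flag = torch.tensor(1.0 if found_oom else 0.0, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # same sum on every rank
    return should_skip_step(flag.item())
```

This keeps the group in sync at the cost of one extra all-reduce per iteration; it handles OOMs caught in the forward pass, but not an OOM raised mid-backward.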

If so, is there a nice way to recover from this problem?

Yep, TorchElastic is built to solve this issue. cc @Kiuk_Chung


Have a look at:

  1. https://pytorch.org/elastic/0.2.0/train_script.html - for instructions on how to write a “torchelastic compliant” train script
  2. https://pytorch.org/elastic/0.2.0/quickstart.html - for a quickstart on launching your script with torchelastic
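For a rough idea of what launching looks like, here is a sketch based on the torchelastic 0.2.0 quickstart (the host names, job id, and worker counts are placeholders you would substitute for your own setup):

```shell
# Launch a fault-tolerant job that can run on 1 to 4 nodes, with
# rendezvous coordinated through an etcd server. If a worker fails
# (e.g. crashes on OOM), torchelastic restarts the workers from the
# last checkpoint instead of leaving the group hung.
python -m torchelastic.distributed.launch \
    --nnodes=1:4 \
    --nproc_per_node=8 \
    --rdzv_id=my_job_id \
    --rdzv_backend=etcd \
    --rdzv_endpoint=etcd-host:2379 \
    train_script.py --arg1 --arg2
```

Note that your train script has to be written to be restart-safe (load from the latest checkpoint on startup), as described in the first link.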

Thank you, I will try it out.