Hi, I’m wondering how to deal with occasional OOM errors that happen during the DDP backward pass.
For the forward pass, an OOM can be caught simply with a try/except statement. For backward, however, loss.backward() computes the gradients while the registered hooks perform gradient reduction at the same time.
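For the forward pass, the try/except approach might look like the sketch below. It is plain Python so it runs without a GPU. The helper names `is_oom_error`, `guarded_forward` and `fake_forward` are hypothetical. The check assumes the OOM surfaces as a RuntimeError whose message contains "out of memory", which matches PyTorch's CUDA OOM errors (newer releases raise `torch.cuda.OutOfMemoryError`, a RuntimeError subclass).

```python
def is_oom_error(exc: BaseException) -> bool:
    """Heuristic: treat a RuntimeError mentioning 'out of memory' as an OOM.

    Assumption: this matches PyTorch's CUDA OOM message ("CUDA out of memory. ...").
    """
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc)


def guarded_forward(forward_fn, *args, **kwargs):
    """Run forward_fn and return (output, oom_flag) instead of raising on OOM.

    Any non-OOM error is re-raised unchanged.
    """
    try:
        return forward_fn(*args, **kwargs), False
    except RuntimeError as exc:
        if not is_oom_error(exc):
            raise
        # In real PyTorch code you would typically drop references to the
        # partial graph here and call torch.cuda.empty_cache().
        return None, True


# Hypothetical stand-in for a model forward that OOMs on "large" inputs.
def fake_forward(x):
    if x > 2:
        raise RuntimeError("CUDA out of memory. Tried to allocate 2.00 GiB")
    return x * 10


print(guarded_forward(fake_forward, 1))  # (10, False)
print(guarded_forward(fake_forward, 5))  # (None, True)
```

Catching the error this way only recovers the local process. Under DDP it does nothing to keep the other ranks in step.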
Is it possible for training to hang if OOM errors occur during backward in some of the processes, so that the other, successful processes keep waiting for them? If so, is there a nice way to recover from this problem?
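Yes, it is. If one process hit OOM and skipped or reran the backward pass, it would desynchronize the processes in the same group: the others would still be waiting on gradient-reduction collectives that this process never issues (or issues a different number of times). That leads to a hang or a crash.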
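To make the skipped-backward case concrete, here is a toy, stdlib-only simulation (not PyTorch code). Threads stand in for ranks and a `threading.Barrier` stands in for DDP's gradient all-reduce, which in reality is split into several bucketed collectives launched from the backward hooks. The names `simulate_step`, `WORLD_SIZE` and `TIMEOUT_S` are hypothetical. The timeout exists only so the demo terminates; a real collective may simply block.

```python
import threading

WORLD_SIZE = 3
TIMEOUT_S = 1.0  # demo-only: lets the "hang" surface as an error instead of blocking forever


def simulate_step(oom_rank=None):
    """Simulate one backward pass; return {rank: status}, sorted by rank."""
    # Stand-in for the gradient all-reduce: every rank must arrive before any proceeds.
    allreduce = threading.Barrier(WORLD_SIZE, timeout=TIMEOUT_S)
    status = {}
    lock = threading.Lock()

    def worker(rank):
        if rank == oom_rank:
            # This rank "hits OOM", catches it and skips backward,
            # so it never joins the collective.
            result = "skipped (OOM)"
        else:
            try:
                allreduce.wait()
                result = "reduced"
            except threading.BrokenBarrierError:
                # Raised once the timeout expires because a peer never arrived.
                result = "timed out waiting for peers"
        with lock:
            status[rank] = result

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dict(sorted(status.items()))


print(simulate_step())
# {0: 'reduced', 1: 'reduced', 2: 'reduced'}
print(simulate_step(oom_rank=1))
# {0: 'timed out waiting for peers', 1: 'skipped (OOM)', 2: 'timed out waiting for peers'}
```

When every rank participates, the collective completes. As soon as one rank drops out, the remaining ranks are left waiting for a peer that will never arrive.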
If so, is there a nice way to recover from this problem?