Handling OOM in backward pass using DDP

Thanks for filing the GitHub issue! Just to confirm the scenario: sometimes you hit an OOM in the fwd pass (which is handled by your try-catch block), whereas an OOM in the bwd pass results in the RuntimeError you posted.
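
Just so we are talking about the same pattern, here is a minimal sketch of the kind of try-catch handling I'm assuming you have (the `run_iteration` helper, `loss_fn`, and the skip-on-OOM behavior are placeholders, not your actual code):

```python
import torch

def run_iteration(model, optimizer, loss_fn, inputs, targets):
    """One training step that skips the batch on a CUDA OOM (hypothetical helper)."""
    try:
        optimizer.zero_grad()
        outputs = model(inputs)          # an OOM raised here is caught below
        loss = loss_fn(outputs, targets)
        loss.backward()                  # an OOM here is what produces the RuntimeError you posted
        optimizer.step()
        return loss.item()
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        # All locals (outputs, loss, the partial autograd graph) are released when
        # we return, so the caching allocator can reuse that memory on the next batch.
        return None
```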

Are you able to complete an iteration of training without seeing an OOM? If not, the runtime error may actually be due to some value returned by the fwd function that’s not used in the loss computation.
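
If that turns out to be the case, one option worth trying (hedged, since I can't see your model) is constructing DDP with `find_unused_parameters=True`, which tells the reducer to tolerate outputs/parameters that don't feed into the loss:

```python
from torch.nn.parallel import DistributedDataParallel as DDP

# model and device_id come from your existing setup; shown here only for context.
model = DDP(
    model.to(device_id),
    device_ids=[device_id],
    find_unused_parameters=True,  # tolerate fwd outputs that are not used in the loss
)
```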

Regarding future debugging: first, here is another thread about why calling torch.cuda.empty_cache() is not recommended. Beyond that, I can think of the following ways to get around the OOM issue more robustly:

  • Use Model Parallelism. For CPU-based models you can check out the RPC framework (we are working on robust GPU support for the RPC framework). Otherwise, you can split the model manually, call the forward function on each shard, and move activations around using .to() (see the sketch after this list). Here is a recent question about this.
  • Try reducing the batch size
  • Use an optimizer that stores less local state per parameter (e.g., SGD instead of Adam)
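
To illustrate the manual splitting option from the first bullet, here is a minimal two-GPU sketch (the layer sizes and device ids are made up; adapt them to your model):

```python
import torch
import torch.nn as nn

class TwoShardModel(nn.Module):
    """Hypothetical model split across two GPUs, moving activations with .to()."""
    def __init__(self):
        super().__init__()
        self.shard1 = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU()).to("cuda:0")
        self.shard2 = nn.Linear(4096, 10).to("cuda:1")

    def forward(self, x):
        x = self.shard1(x.to("cuda:0"))
        # Move the intermediate activation to the second device before the next shard.
        return self.shard2(x.to("cuda:1"))

model = TwoShardModel()
out = model(torch.randn(8, 1024))   # output lives on cuda:1
```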