Surviving OOM events in distributed training

Because my input sizes vary, I hit OOM exceptions during training (in my setting, roughly 1 in every few hundred minibatches). For domain-specific reasons, I prefer not to crop or resize inputs to a constant size. Also, there is no clear way to know in advance which input sizes will cause an OOM.
This is fine with me as long as I can recover from the OOM events and continue training.

When I detect an OOM event, I “clean up” by calling torch.cuda.empty_cache() and zeroing the gradients, and then continue training as usual. This works well in a non-distributed setup, but causes problems in a distributed setting.
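For reference, the non-distributed recovery described above can be sketched roughly as follows. This is a minimal sketch assuming a standard model/optimizer loop; `train_step` and `is_oom_error` are hypothetical helper names, and matching on the "out of memory" message is a heuristic rather than an official API.

```python
from typing import Callable

import torch
import torch.nn as nn


def is_oom_error(exc: BaseException) -> bool:
    # Heuristic: CUDA OOM is raised as a RuntimeError whose message contains
    # "out of memory" (recent releases raise torch.cuda.OutOfMemoryError,
    # which is a RuntimeError subclass with the same message).
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc)


def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    inputs: torch.Tensor,
    targets: torch.Tensor,
) -> bool:
    """Run one training step; return False if the batch was skipped due to OOM."""
    optimizer.zero_grad()
    loss = None
    oom = False
    try:
        loss = loss_fn(model(inputs), targets)
        loss.backward()
    except RuntimeError as exc:
        if not is_oom_error(exc):
            raise  # a genuine error, not an OOM: let it propagate
        oom = True

    if oom:
        # Clean up outside the except block, after the exception (and the
        # frames its traceback references) has been released.
        loss = None                 # drop the reference to any partial graph
        optimizer.zero_grad()       # discard any partial gradients
        torch.cuda.empty_cache()    # no-op if CUDA was never initialized
        return False

    optimizer.step()
    return True
```

The caller simply moves on to the next minibatch when `train_step` returns False.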

note - I am following the suggested way to deal with OOM as mentioned here:

To deal with OOM in a distributed setting, I do something like this:

    if problems_occured:
        success_tens = torch.ones(0)
    else:
        success_tens = torch.ones(1)

    dist.all_reduce(success_tens, op=dist.reduce_op.SUM)  # error happens here

Then, only if the reduced success_tens equals world_size, I do an additional all_reduce over the gradients to sum them.
This makes sure that every worker succeeded in computing its own gradient before the gradients are combined.
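The gradient-combining step described above could look roughly like this. It is a sketch assuming the default process group is already initialized (e.g. via `dist.init_process_group`) and that every rank holds the same model; `all_reduce_gradients` is a hypothetical helper name.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


def all_reduce_gradients(model: nn.Module, average: bool = False) -> None:
    """Sum every parameter gradient across all ranks (optionally divide by world size).

    Every rank must call this with the same model, in the same order, otherwise
    the collectives will not line up and the job can hang.
    """
    world_size = dist.get_world_size()
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameter: identical on all ranks, nothing to reduce
        if param.grad is None:
            # Materialize a zero gradient so every rank issues the same number
            # of all_reduce calls, even if this parameter was unused locally.
            param.grad = torch.zeros_like(param)
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        if average:
            param.grad.div_(world_size)
```

`torch.nn.parallel.DistributedDataParallel` performs an equivalent (averaging) reduction automatically during `backward()`; a manual loop like this one makes it easy to skip the reduction entirely when some rank reported a failure.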

However, after catching the OOM event, the worker that caught it fails with the following error:

miniconda3/envs/py36torch/lib/python3.6/site-packages/torch/distributed/", line 838, in all_reduce
RuntimeError: [enforce fail at /opt/conda/conda-bld/pytorch_1544174967633/work/third_party/gloo/gloo/] opts.elements > 0

Note: as the error message shows, I’m currently using gloo as the distributed backend.

Any suggestions on how to solve this are very welcome!

Note: while I’m currently using a simple synchronized distributed gradient computation, I’m open to any suggestions, as long as they help me survive occasional (relatively rare) OOM events.

@ptrblck any idea? The tl;dr is that I can survive OOM events in a non-distributed setting (just discarding the training minibatch, cleaning up memory and continuing), but in a distributed setting I can’t.
This is important for me because I work in medical imaging, where resolution matters and cropping is too destructive.
Any suggestions and/or pointers to relevant people are very welcome!

Sorry, I’m not really familiar with distributed training and gloo, so I can’t give any useful input to fix this issue.

However, have you thought about other approaches to avoid the OOM issues?
torch.utils.checkpoint might be worth a try (although I’m not sure how it behaves in a distributed setup), as might NVIDIA’s apex for mixed precision training.

Let me know if this would be an option.
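A minimal sketch of activation checkpointing with torch.utils.checkpoint, assuming a model built from repeated blocks (the `CheckpointedMLP` class is a hypothetical example, not from the thread). Each block's activations are recomputed during backward instead of being kept in memory, trading extra compute for lower peak memory. The `use_reentrant=False` argument requires a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(nn.Module):
    """Stack of blocks whose activations are recomputed during backward in training mode."""

    def __init__(self, width: int, depth: int):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(width, width), nn.ReLU()) for _ in range(depth)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.training:
                # Store only the block input; re-run the block during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```

Because recomputation is deterministic for these layers, the gradients match those of the plain (non-checkpointed) forward pass.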

Thanks! I’m already using both checkpointing and mixed precision, which made the OOM events pretty rare, but they still happen here and there. Perhaps it’s reasonable to consider this a bug or a feature request and just report it on GitHub.

Hi @yoelshoshan! The error message indicates that the tensor that you’re passing to allreduce is empty. Is it possible that the “success_tens” itself is somehow empty? Not sure this is possible, but since you’re already dealing with an OOM…

Hi! First of all, thanks for trying to help.
success_tens is not empty.

I managed to build NCCL, and when I use it as the backend for the distributed functions this issue does not happen, so I believe this is a gloo-specific problem.

OK, I just realized what’s going on here. I misunderstood the code snippet you listed in the original post. If you see an OOM, you create an empty tensor, which is why the error triggers. Instead of torch.ones(0), you’ll want to use torch.zeros(1).
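To see why the original snippet produced an empty tensor: the argument to torch.ones and torch.zeros is a size, not a fill value. A quick check (the variable names are illustrative):

```python
import torch

buggy_failure = torch.ones(0)       # size 0: an empty tensor with no elements
fixed_failure = torch.zeros(1)      # size 1: one element holding 0.0
success = torch.ones(1)             # size 1: one element holding 1.0
scalar_failure = torch.tensor(0.0)  # 0-dim tensor: also exactly one element

print(buggy_failure.numel(), fixed_failure.numel(), success.numel(), scalar_failure.numel())
# prints: 0 1 1 1
```

The first case is the one gloo's `opts.elements > 0` check rejects: an all_reduce over zero elements.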


As you suggested, using torch.tensor(0.0) or torch.tensor(1.0) does not trigger the issue.

Thanks for the help! <3
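Putting the fix together, a per-step consensus check might look like the sketch below. It assumes an initialized default process group and that every rank calls it exactly once per step, including the rank that hit the OOM (otherwise the collective never completes). `all_ranks_succeeded` is a hypothetical helper name; with the NCCL backend the flag has to live on the current CUDA device, hence the `device` parameter.

```python
import torch
import torch.distributed as dist


def all_ranks_succeeded(local_ok: bool, device: torch.device = torch.device("cpu")) -> bool:
    """Return True only if every rank reports success for this step.

    Each rank contributes a one-element tensor (1.0 on success, 0.0 on failure),
    so the tensor is never empty regardless of which branch a rank took.
    """
    flag = torch.tensor([1.0 if local_ok else 0.0], device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    # The summed flag equals world_size only when no rank reported a failure.
    return int(flag.item()) == dist.get_world_size()
```

In the training loop, each rank would call `all_ranks_succeeded(ok)` after its own forward/backward; if it returns True, all ranks reduce their gradients and step, otherwise all ranks zero their gradients and skip the batch together. A `dist.ReduceOp.MIN` reduction over the same flag would work equally well.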