My inputs vary in size, and I occasionally hit OOM exceptions during training (in my setting, roughly one in a few hundred minibatches). For domain-specific reasons, I prefer not to crop/resize inputs to a constant size, and there is no clear way to know in advance which input sizes will cause an OOM.
This is generally fine by me as long as I can recover from the OOM events and continue training.
When I detect an OOM event, I "clean up" by calling torch.cuda.empty_cache(), zeroing the gradients, and then continuing training as usual. This works great in a non-distributed setup, but creates problems in a distributed setting.
Note - I am following the suggested way to deal with OOM, as mentioned here:
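Concretely, my single-process recovery logic looks roughly like the sketch below; model, optimizer, and the loss computation are placeholders for my actual training code.

import torch

def run_step(model, optimizer, batch):
    try:
        optimizer.zero_grad()
        loss = model(batch).sum()  # stand-in for my real loss computation
        loss.backward()
        optimizer.step()
        return True
    except RuntimeError as e:
        # only swallow OOM errors; re-raise anything else
        if 'out of memory' not in str(e):
            raise
        # clean up: discard partial state, zero gradients, free cached blocks
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        return False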
To deal with OOM in a distributed setting, I do something like this:
if problems_occurred:
    success_tens = torch.ones(0)
else:
    success_tens = torch.ones(1)
dist.all_reduce(success_tens, op=dist.reduce_op.SUM)  # error happens here
Then, only if success_tens sums to world_size, I do an additional all_reduce over the gradients to sum them (as sketched below). This ensures that every worker succeeded in computing its own gradients before the gradients are combined.
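To make that step concrete, the gradient-combining part looks roughly like this sketch; model and optimizer are placeholders for my actual objects, and world_size comes from dist.get_world_size():

if success_tens.item() == world_size:
    # every worker computed a gradient, so it is safe to combine them
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data, op=dist.reduce_op.SUM)
    optimizer.step()
else:
    # at least one worker hit an OOM; skip this minibatch on all workers
    optimizer.zero_grad()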
However, after the OOM event is caught in the worker that hit it, I get the following error:
miniconda3/envs/py36torch/lib/python3.6/site-packages/torch/distributed/distributed_c10d.py", line 838, in all_reduce
work.wait()
RuntimeError: [enforce fail at /opt/conda/conda-bld/pytorch_1544174967633/work/third_party/gloo/gloo/allreduce.cc:29] opts.elements > 0
Note: as can be seen in the error message, I'm currently using gloo as the distributed backend.
Any suggestions on how to solve this are very welcome.
Note - while I'm currently using a simple synchronized distributed gradient computation, I'm open to any suggestions, as long as they help survive occasional (relatively rare) OOM events.