How to handle exceptions in DistributedDataParallel?

I’m using DistributedDataParallel to train my model. If one process hits an exception during the forward pass, catches it with a try...except block, and then continues training with a new batch of data, all the processes hang (I guess because synchronization fails?). How can I handle an exception in one process and continue training without hanging all the processes? Thanks for the help!

All communication done through torch.distributed is collective, meaning all processes expect all their peers to participate in all collective calls they execute. If a single process ends up not participating, the others will time out or raise an exception. The only way out of this is to let all processes time out or fail and then reinitialize the distributed module.

You can use torch.distributed.destroy_process_group to deinitialize and then make another call to torch.distributed.init_process_group to reinitialize. This can only work if you’re using either the Gloo or the NCCL backend and the underlying initialization method can be reused. I believe this is the case for the file initialization method as well as the TCP initialization method (see https://pytorch.org/docs/stable/distributed.html#initialization for more information on both).
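As a rough illustration of the destroy-then-reinitialize sequence, here is a minimal sketch. The helper names (fresh_file_init_method, reinit_process_group) are hypothetical, and the sketch assumes the Gloo backend and the file initialization method. It also assumes a new, not-yet-existing rendezvous file for each initialization, which is a cautious choice on my part rather than something confirmed above. In a real job every rank must receive the same path.

```python
import os
import tempfile
from datetime import timedelta

import torch.distributed as dist


def fresh_file_init_method() -> str:
    """Hypothetical helper: build a file:// init method on a path that does not exist yet.

    In a multi-process job, one process would create this path and share it
    with every other rank; here it is only generated locally.
    """
    directory = tempfile.mkdtemp()
    return "file://" + os.path.join(directory, "rendezvous")


def reinit_process_group(init_method, rank, world_size,
                         backend="gloo", timeout=timedelta(seconds=30)):
    """Hypothetical helper: tear down the default process group (if any) and create a new one.

    Every rank has to call this with the same init_method and world_size,
    otherwise init_process_group blocks until its timeout expires.
    """
    if dist.is_initialized():
        dist.destroy_process_group()
    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )
```

With a single process this can be exercised as reinit_process_group(fresh_file_init_method(), rank=0, world_size=1). With several processes, how the new init_method gets to every rank is left open here.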

Good luck!

Thanks for the help! I get the idea, but how can I “let all processes time out or fail”? How can all the processes know that one process has hit an exception, so that they should all call destroy_process_group and reinitialize? If one process hits an exception on one minibatch, can all the processes simply skip the current minibatch and run the next one?

I found the following snippet in the PyTorch repo, which might be helpful, but I’m not sure how to implement the idea in detail.

def test_barrier_timeout_global(self):
        dist.destroy_process_group()

        # Explicitly pass world size to the barrier because we've
        # just destroyed any state in torch.distributed.
        self._barrier(wait_for=int(WORLD_SIZE))

        # Reinitialize global process group
        timeout = timedelta(seconds=0.2)
        dist.init_process_group(
            init_method=INIT_METHOD,
            backend=BACKEND,
            world_size=int(WORLD_SIZE),
            rank=self.rank,
            timeout=timeout,
        )
        self._test_barrier_timeout(dist.group.WORLD, timeout)

(from test_distributed.py)

One way to do this is to use a smaller timeout. The default timeout for the distributed module is 30 minutes. You can override this by passing the timeout keyword argument to init_process_group as a timedelta (e.g. datetime.timedelta(seconds=10)). Then if one of the processes crashes, the others will time out. The problem with your proposed solution is that you’re not guaranteed that the crashed process will come back. Therefore you’ll have to rely on some out-of-band mechanism to figure out which processes are still alive, and only when you know for sure you have WORLD_SIZE machines (or after adjusting WORLD_SIZE), continue and reinitialize.
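To make the timeout suggestion concrete, here is a small sketch, assuming the Gloo backend, where a collective that times out surfaces as a RuntimeError (or a subclass of it). The helper names init_with_short_timeout and peers_responded are hypothetical. The sketch only shows how a surviving rank could notice the failure. It does not implement the out-of-band liveness check, which still has to come from somewhere else.

```python
from datetime import timedelta

import torch.distributed as dist


def init_with_short_timeout(init_method, rank, world_size, seconds=10.0):
    """Hypothetical helper: create the default process group with a short collective timeout."""
    dist.init_process_group(
        backend="gloo",  # assumption: Gloo; NCCL handles timeouts differently
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        timeout=timedelta(seconds=seconds),
    )


def peers_responded() -> bool:
    """Hypothetical helper: return True if a barrier completes, False if it times out or fails."""
    try:
        dist.barrier()
        return True
    except RuntimeError:
        # A peer crashed or did not reach the barrier within the timeout.
        return False
```

When peers_responded() returns False, the surviving ranks would call destroy_process_group and wait until they know how many processes are alive before reinitializing.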

Is there any clean way of accomplishing this now? I’m training on images with variable sizes and every ~30k iterations there’s an OOM error. I’m having trouble understanding how to synchronize the call to init_process_group between all the processes.

We currently don’t support elasticity of workers, which means that if one of the processes crashes with an OOM error, the user is responsible for spawning another process and re-initializing distributed communication if they want to continue training.

You may want to consider the use of the PyTorch elastic framework: https://github.com/pytorch/elastic