How to have a process wait for the other with `DistributedDataParallel`?

The first time my training script runs, the dataset is ‘compiled’ in the `Dataset` class. For instance, I often work with medical data: I might add another dataset and want to cut it into tiles of a specified shape. These tiles are then cached to the SSD, so the next time the compilation phase is skipped.
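A minimal sketch of such a ‘compile once, then cache’ step, assuming the data can be loaded as a NumPy array and the tiles are stored as `.npy` files. `build_tile_cache` and the `DONE` marker file are hypothetical names, not part of any library:

```python
import itertools
from pathlib import Path

import numpy as np


def build_tile_cache(volume: np.ndarray, tile_shape: tuple, cache_dir: Path) -> list:
    """Cut `volume` into non-overlapping tiles and cache them as .npy files.

    Hypothetical helper: if the cache already exists (marked by a DONE file),
    tiling is skipped and the existing tile paths are returned.
    """
    cache_dir = Path(cache_dir)
    marker = cache_dir / "DONE"
    if marker.exists():
        return sorted(cache_dir.glob("tile_*.npy"))

    cache_dir.mkdir(parents=True, exist_ok=True)
    # Number of whole tiles along each axis; a partial tile at the border is dropped.
    counts = [dim // t for dim, t in zip(volume.shape, tile_shape)]
    for idx in itertools.product(*(range(c) for c in counts)):
        slices = tuple(slice(i * t, (i + 1) * t) for i, t in zip(idx, tile_shape))
        name = "tile_" + "_".join(map(str, idx)) + ".npy"
        np.save(cache_dir / name, volume[slices])
    # Write the marker last so that an interrupted run is redone next time.
    marker.touch()
    return sorted(cache_dir.glob("tile_*.npy"))
```

Dropping the border remainder keeps every tile the same shape; padding the volume first would be an alternative if the border must be kept.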

However, when I use `DistributedDataParallel`, every process does this. I could check the rank of the process and only let rank == 0 do it, but then the other processes crash because they find an empty dataset. Is there a way to tell the other processes to wait before they start training?

You could do one of two things:

  1. Split your input dataset into `WORLD_SIZE` chunks and let every process preprocess its own subset (this also parallelizes the preprocessing). Call `torch.distributed.get_rank` to get the rank of the current process and use it to index into your input dataset.
  2. As you suggest, have rank == 0 perform all preprocessing and make all other workers wait until it completes. You can do this with a call to `torch.distributed.barrier`. The barrier is bound by the process group's timeout, which defaults to 30 minutes; if preprocessing takes longer, raise it through the `timeout` kwarg of `torch.distributed.init_process_group`.
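A minimal sketch of option 2, assuming the script is launched with `torchrun` (which sets `MASTER_ADDR`, `MASTER_PORT`, `RANK` and `WORLD_SIZE`). `init_distributed` and `preprocess_on_rank_zero` are hypothetical helper names, and the two-hour timeout is an arbitrary example value:

```python
import datetime

import torch.distributed as dist


def init_distributed(backend: str = "gloo", hours: float = 2.0) -> None:
    """Initialise the default process group with a longer timeout.

    Assumes the env:// variables are set (e.g. by torchrun); use "nccl" for GPUs.
    """
    dist.init_process_group(
        backend=backend,
        timeout=datetime.timedelta(hours=hours),  # also bounds dist.barrier()
    )


def preprocess_on_rank_zero(preprocess_fn) -> None:
    """Run preprocess_fn on rank 0 only; every other rank waits at the barrier."""
    if dist.get_rank() == 0:
        preprocess_fn()
    # All ranks block here until every rank (including rank 0, once it has
    # finished preprocessing) has reached the barrier.
    dist.barrier()
```

On several machines without shared storage, only the node hosting global rank 0 would get the cache; one option, assuming each node has its own local disk, is to let the process with `LOCAL_RANK == 0` on every node preprocess before the same barrier.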

Are you running on a single machine or multiple machines?

Currently I am running on one machine, but ideally the solution would work on both a single machine and multiple machines.

Thank you for the suggestions; both seem reasonable. However, it is hard to know beforehand how long the processing will take (it also depends on network speed and the like), so the timeout parameter would need some tweaking. I also do not know beforehand exactly how large the dataset will be, since new samples may have been added, so it would be tricky to write a class that splits the dataset into smaller ones, as that would require knowing its size.

So perhaps a combination of 1 and 2 would work: I make a rough split across `WORLD_SIZE` based on `get_rank`, so some processes may finish earlier than others. If I then call `torch.distributed.barrier()` at the end of dataset processing, the preprocessing is split across the processes and they all wait until each one has finished its part. Do I understand this correctly?
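The combined approach can be sketched as follows, assuming every rank can list the same input files (for example from a shared directory). `shard_for_rank`, `preprocess_in_parallel` and the `preprocess_one` callback are hypothetical names. A strided split (`items[rank::world_size]`) needs no advance knowledge of the dataset size, which addresses the concern about not knowing how large the dataset will be:

```python
import torch.distributed as dist


def shard_for_rank(items, rank, world_size):
    """Strided split: rank r gets items r, r + world_size, r + 2 * world_size, ...

    Works for any number of items, so the total size need not be known up front.
    """
    return items[rank::world_size]


def preprocess_in_parallel(input_paths, preprocess_one):
    """Each rank preprocesses its own share, then all ranks wait for each other."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Sort so that every rank, listing the inputs independently, sees the same order.
    for path in shard_for_rank(sorted(input_paths), rank, world_size):
        preprocess_one(path)
    # No rank proceeds to training until every rank has finished its share.
    dist.barrier()
```

The final barrier is still subject to the process group's `timeout`, but it now only has to cover the slowest rank's share rather than the whole dataset.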

Yes, that’s correct. If the split is only rough, you’ll still have to synchronize on the actual dataset size once preprocessing has completed: `DistributedDataParallel` expects all workers to be called with an equal number of equally sized batches.
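One way to do that synchronization, sketched under the assumption that each rank knows how many samples it ended up with locally: take the minimum batch count across ranks with `all_reduce` and have every rank run exactly that many batches. `common_num_batches` is a hypothetical helper:

```python
import torch
import torch.distributed as dist


def common_num_batches(local_num_samples: int, batch_size: int,
                       drop_last: bool = True, device=None) -> int:
    """Return the number of batches every rank should run: the minimum across ranks.

    With the NCCL backend, pass the current CUDA device so the tensor lives on the GPU.
    """
    if drop_last:
        local_batches = local_num_samples // batch_size
    else:
        local_batches = (local_num_samples + batch_size - 1) // batch_size  # ceil
    count = torch.tensor([local_batches], dtype=torch.long, device=device)
    dist.all_reduce(count, op=dist.ReduceOp.MIN)
    return int(count.item())
```

If all ranks instead read the same complete cache after the barrier (for example on a single machine or a shared filesystem), `torch.utils.data.distributed.DistributedSampler` already gives every rank the same number of samples, padding by repeating indices when `drop_last=False` or dropping the tail when `drop_last=True`.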
