Synchronizing/pausing all processes but one with DistributedDataParallel

Hello,
I’m working with the code example from the official ImageNet distributed tutorial.
Basically, the code uses torch.multiprocessing.spawn(main_worker) to run a copy of the “main worker” function on each GPU.
Then, in each worker dist.init_process_group is called, and the model and dataset/dataloader are created and wrapped in torch.nn.parallel.DistributedDataParallel(model) and torch.utils.data.distributed.DistributedSampler(dataset) respectively.
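For context, here is a rough sketch of that setup as I understand it (this is not the tutorial code itself; the backend, init_method, and the tiny model/dataset are just placeholders I made up to keep it self-contained):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main_worker(rank, world_size):
    # Each spawned process joins the same process group.
    dist.init_process_group(backend="nccl",
                            init_method="tcp://127.0.0.1:23456",
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Wrap the model in DistributedDataParallel on this process's GPU.
    model = torch.nn.Linear(10, 2).cuda(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    # DistributedSampler gives each rank a disjoint shard of the dataset.
    dataset = TensorDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    # ... training loop over epochs ...


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size)
```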

My problem is that after every epoch I want to modify and rewrite all the data in the dataset. I thought the most straightforward way was to run this shuffling on only one of the nodes, inside an “if” like if args.multiprocessing_distributed and args.rank == 1:, so that not all the nodes perform the shuffling simultaneously, similar to line 252 of the code (see the sketch below).
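Roughly, this is what I mean (rewrite_dataset is just a placeholder for my slow shuffling/rewriting code, and args comes from the tutorial’s argument parsing):

```python
# End of epoch: only rank 1 rewrites the data on disk.
if args.multiprocessing_distributed and args.rank == 1:
    rewrite_dataset()  # placeholder for my >10 min rewrite step
# The other ranks do not wait here, so they can start the next epoch
# while the files are still being rewritten.
```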

The operation of rewriting the dataset takes a long time (>10 minutes). This creates a problem where the processes that don’t perform the shuffling try to access the data while it is still being rewritten by the node that modifies it.

Is there any suggested way of making the rest of the processes wait for the one process to finish the modification and writing of the data before continuing the training?

I found this “barrier” method from the PyTorch distributed package.
And also this section about synchronization in Python’s documentation.
But because the error I get isn’t 100% reproducible (in some runs it appears 10 minutes after the start, in others it appears hours into the run), I can’t really test those implementations, and I couldn’t find any easy-to-follow examples.

Any suggestions would be appreciated :sweat_smile:

dist.barrier() should help here: it blocks all processes in the group until every rank has reached the same barrier. What error did you see after adding a barrier?
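A minimal sketch of the pattern, assuming the rewrite happens at the end of each epoch (rewrite_dataset and args are placeholders for your own code, not anything from the tutorial):

```python
import torch.distributed as dist


def rewrite_dataset():
    # Placeholder for the slow (>10 min) on-disk shuffling/rewriting step.
    ...


def reshuffle_and_sync(args):
    # Only one rank rewrites the data; the guard mirrors the one in your post.
    if args.multiprocessing_distributed and args.rank == 1:
        rewrite_dataset()
    # Every rank (including the one that did the rewrite) blocks here until
    # all ranks in the default process group have reached the barrier, so no
    # process starts the next epoch on half-written data.
    dist.barrier()
```

One thing to keep in mind: if the rewrite could ever exceed the process group’s default timeout (30 minutes, if I remember correctly), you may need to pass a larger timeout to init_process_group, otherwise the ranks waiting at the barrier can error out before the rewrite finishes.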