DistributedDataParallel barrier doesn't work as expected during evaluation

Hi, I am using distributed data parallel as shown in the tutorial. I have 2 GPUs on a single machine. I want to train the model on all ranks but evaluate it only on the process with rank 0, so I set up a barrier around the evaluation. However, rank 1 never gets out of the barrier. My code flow is like this:

    def train(self, resume=False):
        for i in range(self._epoch + 1, self._niter + 1):
            self._train_sampler.set_epoch(i)
            self._train()
            if self._testset is not None and i % 1 == 0:
                # only rank 0 runs the validation
                if not self.distributed or self.rank == 0:
                    print('rank {} go to validation'.format(self.rank))
                    self._validate()
                # every rank should hit this barrier so training stays in sync
                if self.distributed:
                    print('rank {} go to barrier'.format(self.rank))
                    dist.barrier()
                    print('rank {} go out of barrier'.format(self.rank))
            self._epoch = i
            self.save_training(self._cfg.path.CHECKPOINT.format(self._epoch))
            if hasattr(self, '_scheduler'):
                self._scheduler.step()

However, rank 1 freezes after validation. The output is like:
training…
start to validate rank 1 go to barrier
validating…
rank 0 go to barrier
rank 0 go out of barrier
checkpoint is saved in xxx…

Then it just freezes when the next epoch of training starts.

The validation code is like:

    def _validate(self):
        # only rank 0 runs the evaluation
        if isinstance(self._model, DDP) and self.rank != 0:
            return
        print('start to validate')
        self._model.eval()

        results = []
        with torch.no_grad():
            for idx, (inputs, targets) in enumerate(tqdm.tqdm(self._testset, 'evaluating')):
                inputs = self._set_device(inputs)
                output = self._model(inputs)
                batch_size = len(output['boxes'])

                for i in range(batch_size):
                    if len(output['boxes'][i]) == 0:
                        continue
                    # convert boxes from xyxy to xywh
                    output['boxes'][i][:, 2] -= output['boxes'][i][:, 0]
                    output['boxes'][i][:, 3] -= output['boxes'][i][:, 1]
                    for j in range(len(output['boxes'][i])):
                        results.append({'image_id': int(targets[i]['image_id']),
                                        'category_id': output['labels'][i][j].cpu().numpy().tolist(),
                                        'bbox': output['boxes'][i][j].cpu().numpy().tolist(),
                                        'score': output['scores'][i][j].cpu().numpy().tolist()})

        with open('temp_result.json', 'w') as f:
            json.dump(results, f)
        self.eval_result(dataset=self._dataset_name)  # use the COCO tools to evaluate the output file

If I remove the evaluation code, the barrier works as expected and rank 1 gets out of the barrier.
Does anyone know how to solve the problem?


Does your model's forward pass contain any communication (e.g., SyncBatchNorm layers)?

BTW, how long does the validation take? Would I be correct to assume the log shows that rank 0 finishes the barrier successfully but rank 1 doesn't?

I used the resnet_fpn backbone from torchvision and I am implementing Faster R-CNN with FPN myself. I didn't see any SyncBatchNorm layers in the torchvision resnet, so I guess there is no communication.
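For what it's worth, a quick way to confirm this (a minimal sketch, assuming `model` refers to the underlying, unwrapped module):

    import torch.nn as nn

    # True if any submodule is a SyncBatchNorm, which would communicate
    # across ranks during the forward pass
    has_sync_bn = any(isinstance(m, nn.SyncBatchNorm) for m in model.modules())
    print('contains SyncBatchNorm:', has_sync_bn)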

I validated only 15 samples just to debug; it took less than a minute.

Yes, the log shows that rank 0 finishes the barrier successfully but rank 1 doesn't.

I changed the environment from PyTorch 1.6 to the latest nightly build, and the old problem turned into a new one. I ran the same code, but now the barrier doesn't seem to work at all. The output is like:

Training: 0%| | 0/5647 [00:18<?, ?it/s]Average loss : 2.0772855
rank 0 go to validation
start to validate
Training: 0%| | 0/5647 [00:20<?, ?it/s]Average loss : 2.2598593
rank 1 go to barrier
rank 1 go out of barrier
Training: 0%| | 0/5647 [00:00<?, ?it/s]
rank 0 go to barrier
rank 0 go out of barrier
The checkpoint has been saved to /home/dsv/qida0163/Vision/data/embedding/frcnn_coco_person/checkpoints/checkpoints_20201011_dis_1.pkl
Epoch 2/13
Training: 0%|
Then freezing.

This means the barrier didn't work: rank 1 left the barrier before the validation finished and started training the next epoch before rank 0 had finished validating.

I also tried removing the barrier before, but it still ends up freezing.

Just want to provide more information.

So did you solve this problem?

I met the same problem while running validation between training epochs. It seems that dist.barrier() didn't work.

Just as @mrshenli said:

rank 0 finishes barrier successfully but rank 1 doesn’t

Rank 0 finished validation and crossed the barrier, while rank 1 didn't. Everything works fine after removing the validation code.

BTW, the validation was only on rank 0.

@Euruson Do you have a minimal repro we can try out on our end to reproduce this problem?

@pritamdamania87 Yep, I just modified this official tutorial with DDP.

The code is here:

The result is like this:

Rank:0 - Epoch 0/24
Rank:1 - Epoch 0/24
----------
Rank: 0 - train Loss: 0.5293 Acc: 0.7459
Rank: 1 - train Loss: 0.4891 Acc: 0.7623

Rank:1 - Epoch 1/24
Rank: 0 - val Loss: 0.2841 Acc: 0.8889

Rank:0 - Epoch 1/24
----------

It seems that dist.barrier() doesn't work, as rank 1 just goes on to the next epoch without waiting for rank 0's validation, and then the program freezes.

Note that output to the terminal is not always guaranteed to be in the order of the actual operations. Due to things like buffering, output to stdout can appear in a different order from the order in which the operations were actually executed (especially in a multiprocess/multithreaded environment).

To verify this, can you add timestamps to each output line and also print something after the barrier call is done?
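Something along these lines should do it (a minimal sketch; `log` is a hypothetical helper, not part of the tutorial):

    from datetime import datetime

    def log(rank, msg):
        # timestamp every line and flush immediately so buffering cannot
        # reorder the output across processes
        ts = datetime.now().strftime('%H:%M:%S %f')
        print(f'{ts} | Rank:{rank} {msg}', flush=True)

    # e.g. log(rank, 'waiting before the barrier')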

Sure. The minimal repro has been updated as well.

The result (it froze after validation):

16:18:13 723869     | Rank:1 - Epoch 0/24
16:18:13 723866     | Rank:0 - Epoch 0/24
16:18:13 723900     | ----------
16:18:16 888776     | Rank: 0 - train Loss: 0.5663 Acc: 0.6885
16:18:16 896916     | Rank: 1 - train Loss: 0.5002 Acc: 0.7705
16:18:16 896992     | Rank:1 waiting before the barrier
16:18:17 383175     | Rank:1 left the barrier
16:18:17 383215     | Rank:1 - Epoch 1/24
16:18:17 829886     | Rank: 0 - val Loss: 0.2327 Acc: 0.9150
16:18:17 829934     | Rank:0 waiting before the barrier
16:18:17 830029     | Rank:0 left the barrier
16:18:17 830044     | Rank:0 - Epoch 1/24
16:18:17 830051     | ----------

I changed the validation function to time.sleep and the barrier works fine:

16:40:57 446421     | Rank:1 - Epoch 0/24
16:40:57 446420     | Rank:0 - Epoch 0/24
16:40:57 446456     | ----------
16:41:00 635462     | Rank:1 - train Loss: 0.5516 Acc: 0.6885
16:41:00 635536     | Rank:1 waiting before the barrier
16:41:00 635599     | Rank:0 - train Loss: 0.4810 Acc: 0.7705
16:41:00 635663     | Rank:0 sleeping
16:41:05 640713     | Rank:0 awaken
16:41:05 640734     | Rank:0 waiting before the barrier
16:41:05 640875     | Rank:0 left the barrier
16:41:05 640890     | Rank:0 - Epoch 1/24
16:41:05 640882     | Rank:1 left the barrier
16:41:05 640912     | ----------
16:41:05 640935     | Rank:1 - Epoch 1/24
16:41:08 641714     | Rank:1 - train Loss: 0.3519 Acc: 0.8279
16:41:08 641790     | Rank:1 waiting before the barrier
16:41:08 651248     | Rank:0 - train Loss: 0.4229 Acc: 0.8156
16:41:08 651340     | Rank:0 sleeping
16:41:13 656394     | Rank:0 awaken

@pritamdamania87 @mrshenli Any ideas to solve this problem?

@Euruson I think I've figured out the problem here. You are still using DDP for the validation phase even though it runs only on one rank. Even though you might not run the backward pass for DDP during the eval phase, the forward pass for DDP might still invoke some collective operations (e.g., syncing buffers or syncing indices when it rebuilds buckets the first time). As a result, your collective ops are mismatched, and some of the collective ops from DDP's forward pass on rank 0 match up with the barrier() call on rank 1, causing it to leave the barrier early.

If you make the following code change, your script seems to work as expected:

    if phase == "val":
        # bypass the DDP wrapper during validation so no collectives are issued
        outputs = model.module(inputs)
    else:
        outputs = model(inputs)

model.module retrieves the underlying non-replicated model, which you can use for validation. With this change, the output on my local machine is as follows:

19:39:05 071604     | Rank:0 - Epoch 0/24
19:39:05 071607     | Rank:1 - Epoch 0/24
19:39:05 071672     | ----------
19:39:08 620338     | Rank: 1 - train Loss: 0.4468 Acc: 0.7787
19:39:08 620479     | Rank:1 waiting before the barrier
19:39:08 651507     | Rank: 0 - train Loss: 0.5222 Acc: 0.7623
19:39:10 524626     | Rank: 0 - val Loss: 0.2312 Acc: 0.9281
19:39:10 524726     | Rank:0 waiting before the barrier
19:39:10 524973     | Rank:0 left the barrier
19:39:10 524994     | Rank:1 left the barrier
19:39:10 525106     | Rank:1 - Epoch 1/24
19:39:10 525123     | Rank:0 - Epoch 1/24
19:39:10 525156     | ----------
19:39:13 735254     | Rank: 1 - train Loss: 0.3994 Acc: 0.8197
19:39:13 735366     | Rank:1 waiting before the barrier
19:39:13 739752     | Rank: 0 - train Loss: 0.4128 Acc: 0.8197
19:39:15 298398     | Rank: 0 - val Loss: 0.2100 Acc: 0.9216
19:39:15 298483     | Rank:0 waiting before the barrier
19:39:15 298672     | Rank:0 left the barrier
19:39:15 298702     | Rank:0 - Epoch 2/24
19:39:15 298716     | ----------
19:39:15 298728     | Rank:1 left the barrier
19:39:15 298811     | Rank:1 - Epoch 2/24
19:39:18 586375     | Rank: 0 - train Loss: 0.4336 Acc: 0.8156
19:39:18 605651     | Rank: 1 - train Loss: 0.3094 Acc: 0.8893
19:39:18 605791     | Rank:1 waiting before the barrier
19:39:20 199963     | Rank: 0 - val Loss: 0.2205 Acc: 0.9216
19:39:20 200061     | Rank:0 waiting before the barrier
19:39:20 200296     | Rank:0 left the barrier
19:39:20 200329     | Rank:0 - Epoch 3/24
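For reference, the overall epoch loop this change implies looks roughly like this (a minimal sketch; `num_epochs`, `train_sampler`, `train_one_epoch`, `validate`, and `ddp_model` are placeholders, not taken from the tutorial):

    import torch
    import torch.distributed as dist

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)          # reshuffle data per epoch on every rank
        train_one_epoch(ddp_model)              # all ranks run DDP forward/backward

        if dist.get_rank() == 0:
            ddp_model.eval()
            with torch.no_grad():
                validate(ddp_model.module)      # plain module, so no DDP collective ops
            ddp_model.train()

        dist.barrier()                          # every rank issues the same collective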

@pritamdamania87 It works! Thanks a lot!

Any references for more details? The tutorial on the official website only mentions that the backward pass triggers synchronization.

I don’t think this is documented, but under certain conditions there might be a sync during the forward pass: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L675. In addition to this, we rebuild buckets once: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L671, which triggers a sync across ranks: https://github.com/pytorch/pytorch/blob/master/torch/lib/c10d/reducer.cpp#L1377
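One condition that can trigger such a forward-pass sync (an assumption based on DDP's default settings, not taken from the linked code): if the wrapped module has buffers, e.g. BatchNorm running statistics, and broadcast_buffers is left at its default of True, DDP may broadcast those buffers from rank 0 at the start of a forward call, which is a collective that every rank must join. A minimal sketch, assuming `rank` and `inputs` are set up as in the repro:

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # BatchNorm2d registers running_mean / running_var buffers
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).to(rank)

    # broadcast_buffers defaults to True, so forward calls on the wrapper may
    # start with a buffer broadcast from rank 0 (a collective operation)
    ddp_model = DDP(model, device_ids=[rank])

    # bypasses DDP's forward logic entirely, so no collective is issued
    outputs = ddp_model.module(inputs)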


It works for me. Thanks a lot!!!

Is there any documentation discussing the above fix? dist.barrier()'s behaviour is a bit mysterious but this seems to fix most issues related to silent hanging. Thanks again, this fix worked for me.

Thanks a lot for the great answer and detailed explanation!