All ranks reach distributed.barrier() but none pass it

Hi, I’m having trouble with distributed.barrier(). I use it to make the other ranks wait while rank 0 runs the test and saves the parameters; since DDP training keeps the parameters in sync across ranks, I don’t think every rank needs to run the test and save. Here is the relevant part of the code:

    distributed.barrier()    # first barrier
    for epoch in range(resume_epoch, epochs):
        tic = time.time()
        if not cfg.data.transform.dali_pipe:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, train_loader, Loss, optimizer, epoch, lr_scheduler, logger, (top1_acc, loss_record, *train_dc),
                        scaler, gpu, args, cfg)
        if is_first_rank:    # only rank 0 runs the test and saves checkpoints
            one_epoch_time_cost = int(time.time() - tic)
            train_speed = cfg.data.num_training_samples // one_epoch_time_cost
            train_time_cost = "%02d:%02d:%02d" % seconds_to_time(one_epoch_time_cost)
            logger.info(f'Finish one epoch cost {train_time_cost}, speed: {train_speed} samples/s.')
            if not cfg.test.no_test:
                test(model, val_loader, Loss, epoch, logger, (top1_acc, top5_acc, loss_record), gpu)
            acc = top1_acc.get()
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
            }
            torch.save(checkpoint, '{}/{}_{}_{:.5}.pt'.format(args.out_path, cfg.model.network, epoch, acc))
            if acc > best_top1_acc:
                old_backbone = '{}/{}_backbone_{:.5}.pth'.format(args.out_path, cfg.model.network, best_top1_acc)
                if os.path.exists(old_backbone):
                    os.remove(old_backbone)
                best_top1_acc = acc
                torch.save(checkpoint['model'], '{}/{}_backbone_{:.5}.pth'.format(args.out_path, cfg.model.network, acc))

        if cfg.data.transform.dali_pipe.enable:
            train_loader.reset()
   
        logger.info(f"rank:{gpu} got here.")
        distributed.barrier()
        logger.info(f"rank:{gpu} pass here.")

My issue is that all ranks pass the first barrier, and all ranks reach the second barrier, but none of them ever get past it.
Could you please give me some advice?

Would it be possible to come up with a minimal script that we can run to reproduce the issue on our end? Any logs you have would also be helpful.

Also - which distributed comm. backend are you using (Gloo, NCCL, MPI)? If all ranks indeed call into the second barrier but none make it out, I’m guessing there is likely a hang or some other form of desynchronization going on.
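If it helps, the kind of minimal script I mean is just a skeleton like the following. This is only a rough sketch, assuming a single node, mp.spawn, and the NCCL backend (all of those are assumptions, so adjust to your launcher and setup): two barriers with rank-0-only work in between, mirroring your loop.

    # Hypothetical minimal repro sketch: two barriers with rank-0-only work in between.
    # Assumes a single node with N GPUs and the NCCL backend; adapt to your launcher.
    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def worker(rank, world_size):
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        torch.cuda.set_device(rank)
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

        dist.barrier()  # first barrier
        if rank == 0:
            pass  # stand-in for test() + torch.save() on rank 0 only

        print(f"rank:{rank} get second barrier.", flush=True)
        dist.barrier()  # second barrier
        print(f"rank:{rank} pass second barrier.", flush=True)

        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        mp.spawn(worker, args=(world_size,), nprocs=world_size)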

Hi @rvarm1, I can give you the full log right away, but a minimal script may take some time to prepare. I’m using the NCCL backend. Here is the log:

Use GPU: 2 for training
Use GPU: 5 for training
Use GPU: 4 for training
Use GPU: 6 for training
Use GPU: 0 for training
Use GPU: 3 for training
Use GPU: 7 for training
Use GPU: 1 for training
Namespace(data=(classes=1000; num_training_samples=1281167; input_size=224; transform=(type=normal; color_jit=0.4; dali_pipe=(enable=True; dali_cpu=False)); dataloader=(num_workers=40; sampler=distributed_sampler)); model=(network=GhostNetRE; model_setting=width=0.5,dropout=0.1; model_info=True); trainer=(epochs=360; batch_size=256; dtype=float16; mix_precision_training=True; lr_scheduler=(type=cosine; warmup_lr=0; warmup_epochs=5); optimizer=(learning_rate=2.6; momentum=0.9; weight_decay=3e-05; nesterov=True)); test=(no_test=False; crop_ratio=0.875); tricks=(label_smoothing=(enable=True; smoothing=0.1); no_weight_decay=True; lookahead=False; mixup=(enable=False; alpha=0.2); last_gamma=False; sgd_gc=False); logs=(tensorboard=True; log_interval=200; logging_file_name=distribute_train_imagenet.log); resume=(resume_epoch=0; resume_param=None))
Train with FP16.
Reducer buckets have been rebuilt in this iteration.
Reducer buckets have been rebuilt in this iteration.
Reducer buckets have been rebuilt in this iteration.
Reducer buckets have been rebuilt in this iteration.
Reducer buckets have been rebuilt in this iteration.
Reducer buckets have been rebuilt in this iteration.
Reducer buckets have been rebuilt in this iteration.
Reducer buckets have been rebuilt in this iteration.
Epoch 0, Node 0, GPU 7, Iter 200, Top1 Accuracy:0.0029928, Loss:6.847, 622 samples/s. lr: 0.16806.
Epoch 0, Node 0, GPU 3, Iter 200, Top1 Accuracy:0.0031095, Loss:6.8501, 622 samples/s. lr: 0.16806.
Epoch 0, Node 0, GPU 2, Iter 200, Top1 Accuracy:0.0022932, Loss:6.8461, 622 samples/s. lr: 0.16806.
Epoch 0, Node 0, GPU 6, Iter 200, Top1 Accuracy:0.0026625, Loss:6.8474, 622 samples/s. lr: 0.16806.
Epoch 0, Node 0, GPU 0, Iter 200, Top1 Accuracy:0.0027596, Loss:6.8466, 622 samples/s. lr: 0.16806.
Epoch 0, Node 0, GPU 5, Iter 200, Top1 Accuracy:0.0026819, Loss:6.8472, 622 samples/s. lr: 0.16806.
Epoch 0, Node 0, GPU 4, Iter 200, Top1 Accuracy:0.0026819, Loss:6.8492, 622 samples/s. lr: 0.16806.
Epoch 0, Node 0, GPU 1, Iter 200, Top1 Accuracy:0.0023515, Loss:6.8483, 622 samples/s. lr: 0.16806.
Epoch 0, Node 0, GPU 4, Iter 400, Top1 Accuracy:0.012313, Loss:6.5364, 587 samples/s. lr: 0.33446.
Epoch 0, Node 0, GPU 3, Iter 400, Top1 Accuracy:0.012235, Loss:6.5393, 587 samples/s. lr: 0.33446.
Epoch 0, Node 0, GPU 6, Iter 400, Top1 Accuracy:0.012293, Loss:6.5403, 587 samples/s. lr: 0.33446.
Epoch 0, Node 0, GPU 1, Iter 400, Top1 Accuracy:0.012595, Loss:6.5376, 587 samples/s. lr: 0.33446.
Epoch 0, Node 0, GPU 5, Iter 400, Top1 Accuracy:0.012089, Loss:6.5356, 587 samples/s. lr: 0.33446.
Epoch 0, Node 0, GPU 7, Iter 400, Top1 Accuracy:0.012118, Loss:6.5386, 587 samples/s. lr: 0.33446.
Epoch 0, Node 0, GPU 0, Iter 400, Top1 Accuracy:0.012313, Loss:6.5372, 587 samples/s. lr: 0.33446.
Epoch 0, Node 0, GPU 2, Iter 400, Top1 Accuracy:0.01169, Loss:6.5363, 587 samples/s. lr: 0.33446.
Epoch 0, Node 0, GPU 2, Iter 600, Top1 Accuracy:0.027331, Loss:6.2461, 567 samples/s. lr: 0.50086.
Epoch 0, Node 0, GPU 4, Iter 600, Top1 Accuracy:0.02813, Loss:6.2486, 567 samples/s. lr: 0.50086.
Epoch 0, Node 0, GPU 5, Iter 600, Top1 Accuracy:0.027961, Loss:6.2488, 567 samples/s. lr: 0.50086.
Epoch 0, Node 0, GPU 6, Iter 600, Top1 Accuracy:0.027968, Loss:6.2517, 567 samples/s. lr: 0.50086.
Epoch 0, Node 0, GPU 1, Iter 600, Top1 Accuracy:0.028065, Loss:6.2518, 567 samples/s. lr: 0.50086.
Epoch 0, Node 0, GPU 7, Iter 600, Top1 Accuracy:0.027786, Loss:6.2526, 567 samples/s. lr: 0.50086.
Epoch 0, Node 0, GPU 3, Iter 600, Top1 Accuracy:0.027428, Loss:6.2534, 567 samples/s. lr: 0.50086.
Epoch 0, Node 0, GPU 0, Iter 600, Top1 Accuracy:0.027636, Loss:6.2534, 567 samples/s. lr: 0.50086.
Finish one epoch cost 00:04:43, speed: 4527 samples/s.
rank:3 get second barrier.
rank:4 get second barrier.
rank:1 get second barrier.
rank:6 get second barrier.
rank:7 get second barrier.
rank:5 get second barrier.
rank:2 get second barrier.
Test Epoch 0, Top1 Accuracy:0.08558, Top5 Accuracy:0.22848, Loss:5.3841
rank:0 get second barrier.

I see only one log line for "Test Epoch 0, Top1 Accuracy:0.08558, Top5 Accuracy:0.22848, Loss:5.3841". Is it possible that rank 0 is issuing some additional collective comm. as part of this (e.g. all_reduce) that isn’t matched by the other ranks? That may explain the hang.
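For illustration only (this is a hypothetical pattern, not your code): dropping something like the snippet below into the worker sketch above would hang in exactly this way, because the all_reduce issued on rank 0 has no matching call on the other ranks.

    # Hypothetical pattern (not your code): an unmatched collective on rank 0,
    # placed right before the second barrier in the worker sketch above.
    t = torch.ones(1, device=f"cuda:{rank}")
    if rank == 0:
        dist.all_reduce(t)  # only rank 0 issues this collective; it blocks waiting for peers
    dist.barrier()          # meanwhile every other rank blocks here, so no rank advances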

Hi @rvarm1.
I don’t use any distributed op in test; the whole test function is below. There is only one such log line because only rank 0 runs the test.

    @torch.no_grad()
    def test(model, val_loader, criterion, epoch, logger, collectors, gpu):
        top1_acc, top5_acc, loss_record = collectors
        top1_acc.reset()
        top5_acc.reset()
        loss_record.reset()

        model.eval()
        for meta_data in val_loader:
            data = meta_data[0].cuda(gpu, non_blocking=True)
            labels = meta_data[1].cuda(gpu, non_blocking=True)

            outputs = model(data)
            losses = criterion(outputs, labels)

            top1_acc.update(outputs, labels)
            top5_acc.update(outputs, labels)
            loss_record.update(losses)

        test_msg = 'Test Epoch {}, {}:{:.5}, {}:{:.5}, {}:{:.5}'.format(epoch, top1_acc.name, top1_acc.get(), top5_acc.name,
                                                                        top5_acc.get(), loss_record.name, loss_record.get())
        logger.info(test_msg)
        top1_acc.write_tb(name='Val Top1 Accuracy', iteration=epoch)
        top5_acc.write_tb(name='Val Top5 Accuracy', iteration=epoch)
        loss_record.write_tb(name='Val Loss', iteration=epoch)

And if I did use any distributed op, I wouldn’t get past it until all ranks called the same op, right?