PyTorch DDP timeout at inference time

Here is part of my training/testing code:

import math
import os
from datetime import timedelta

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.nn.utils import clip_grad_norm_


def main(configs):
    _n_gpu = int(os.environ.get("WORLD_SIZE", 0))
    _global_rank = int(os.environ.get("RANK", 0))
    _local_rank = int(os.environ.get("LOCAL_RANK", 0))
    envs = {'RANK': _global_rank, 'LRANK': _local_rank, 'nGPU': _n_gpu}
    set_random_seed(configs.seed + _global_rank)
    device = torch.device(_local_rank)

    """ init model and DDP """
    os.environ["NCCL_BLOCKING_WAIT"] = "1"
    if train_idx == 0:
        dist.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(seconds=900))

    model = setup_model(configs)
    model.to(device)
    optimizer = setup_e2e_optimizer(model, configs) 
    model = DistributedDataParallel(model, device_ids=([_local_rank] if configs.cpu != 1 else None), find_unused_parameters=True)

    """ init dataset """
    trainloader, test_loader = setup_dataloaders(configs)  # already initialized with dist sampler
    total_bs = _n_gpu * configs.train_batch_size
    configs.total_steps = int(math.ceil(1. * configs.num_train_epochs * len(trainloader.dataset) / total_bs))
    configs.val_steps = int(math.ceil(1. * configs.total_steps / configs.num_valid))

    """training start"""
    model.train()
    dist.barrier()
    cur_step = 0
    for _, batch in enumerate(InfiniteIterator(trainloader)):
        batch = move_to_cuda(batch, device=device)
        optimizer.zero_grad()
        logits = model(batch)

        batch["labels"] = label2onehot(batch["labels"], cls_number=configs.num_labels)
        losses = my_loss(logits, batch["labels"])

        loss = losses.mean()
        loss.backward()

        zero_none_grad(model)
        if configs.grad_norm != -1:
            grad_norm = clip_grad_norm_(model.parameters(), configs.grad_norm)

        cur_epoch = int(1. * total_bs * cur_step / len(trainloader))
        optimizer, lr_trans, lr_cnn = update_optimizer(optimizer, cur_step, cur_epoch, configs)
        optimizer.step()

        """ validation """
        if cur_step % configs.val_steps == 0 and cur_step != 0:
            print("====> GPU{} !!!!!!!! Entering testing phase !!!!!!!".format(envs['RANK']))
            dist.barrier()  # This is where the timeout occurs
            test_rst = test(model, test_loader, cur_step, envs, configs, device)
            if _global_rank == 0:
                save_model(cur_step, model.module, configs.output_dir, _global_rank)
            dist.barrier()
            
        cur_step += 1
        if cur_step >= configs.total_steps:
            break

The error is like:

====> GPU5 !!!!!!!! Entering testing phase !!!!!!!
Traceback (most recent call last):
  File "youku_cls.py", line 273, in <module>
    main(configs)
  File "youku_cls.py", line 201, in main
    dist.barrier()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2427, in barrier
    work.wait()
RuntimeError: [Rank 5] Caught collective operation timeout: WorkNCCL(OpType=ALLREDUCE, TensorShape=[1], Timeout(ms)=900000) ran for 900319 milliseconds before timing out.
====> GPU4 !!!!!!!! Entering testing phase !!!!!!!
Traceback (most recent call last):
  File "youku_cls.py", line 273, in <module>
    main(configs)
  File "youku_cls.py", line 201, in main
    dist.barrier()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2427, in barrier
    work.wait()
RuntimeError: [Rank 4] Caught collective operation timeout: WorkNCCL(OpType=ALLREDUCE, TensorShape=[1], Timeout(ms)=900000) ran for 900629 milliseconds before timing out.
====> GPU0 !!!!!!!! Entering testing phase !!!!!!!
Traceback (most recent call last):
  File "youku_cls.py", line 273, in <module>
    main(configs)
  File "youku_cls.py", line 201, in main
    dist.barrier()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2427, in barrier
    work.wait()
RuntimeError: [Rank 0] Caught collective operation timeout: WorkNCCL(OpType=ALLREDUCE, TensorShape=[1], Timeout(ms)=900000) ran for 900975 milliseconds before timing out.
====> GPU6 !!!!!!!! Entering testing phase !!!!!!!
Traceback (most recent call last):
  File "youku_cls.py", line 273, in <module>
    main(configs)
  File "youku_cls.py", line 201, in main
    dist.barrier()
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2427, in barrier
    work.wait()
RuntimeError: [Rank 6] Caught collective operation timeout: WorkNCCL(OpType=ALLREDUCE, TensorShape=[1], Timeout(ms)=900000) ran for 901022 milliseconds before timing out.

I use 8 GPUs to run the code, but only 4 of them printed the message

====> GPUX !!!!!!!! Entering testing phase !!!!!!!

So I figure the other 4 GPUs didn't reach the test part.
I wonder why this can happen, since all 8 GPUs should have synced at every backward step.

If only a single GPU is used to evaluate or do something right after training, DDP cannot wait that long, and the timeout error occurs.

I guess your code has a part like that in every epoch or iteration.
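
A minimal sketch of that pattern (my own illustration with a placeholder sleep, not your code), runnable with torchrun: one rank does long, collective-free work while the other ranks are already waiting at a barrier, and the waiting ranks fail once the process-group timeout expires. An NCCL barrier is implemented as an all-reduce of a one-element tensor, which matches the OpType=ALLREDUCE, TensorShape=[1] in your traceback.

import os
import time
from datetime import timedelta

import torch
import torch.distributed as dist

# Hypothetical sketch: rank 0 does long, collective-free work while the other
# ranks are already waiting at a barrier. Once the wait exceeds the timeout
# passed to init_process_group (900 s here), the waiting ranks raise the
# collective-operation-timeout error.
os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")  # as in the snippet above, so the timeout raises instead of hanging
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", timeout=timedelta(seconds=900))

if dist.get_rank() == 0:
    time.sleep(1200)   # stands in for a single-GPU evaluation longer than 900 s
dist.barrier()         # ranks other than 0 block here and hit the NCCL timeout
dist.destroy_process_group()

Raising the timeout in init_process_group only buys more time; if the ranks have genuinely diverged, they will still time out eventually.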

All GPUs run the test part in my code.
And I checked the training log: all GPUs had finished the line

optimizer.step()

before the validation code.

Turns out it's the statement if cur_step % configs.val_steps == 0 that causes the problem.
The size of the dataloader differs slightly across GPUs, leading to a different configs.val_steps on each GPU. So some GPUs jump into the if statement while others don't.

Unify configs.val_steps across all GPUs, and the problem is solved.
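
For anyone who runs into the same thing, one way to do that unification (a sketch reusing configs, device, and dist from the code above; the thread doesn't show the exact fix) is to agree on a single value with a collective right after computing it, e.g. by taking the maximum across ranks:

# Make configs.val_steps identical on every rank by reducing the locally
# computed value; broadcasting rank 0's value would work just as well.
val_steps_t = torch.tensor([configs.val_steps], dtype=torch.long, device=device)
dist.all_reduce(val_steps_t, op=dist.ReduceOp.MAX)
configs.val_steps = int(val_steps_t.item())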

I would never have guessed that the dataloaders have different sizes. Good work. :smile: