Bug of torch.distributed.all_gather()

I use torch.distributed.all_gather to gather output of model from different processes:

# One receive buffer per rank; dist.all_gather fills temp_list in rank order.
temp_list = [torch.zeros_like(batch_pred, dtype=torch.int64)
             for _ in range(utils.get_world_size())]
dist.all_gather(temp_list, batch_pred)
batch_pred = torch.cat(temp_list, dim=0)
# temp_list is reused for the second gather; its previous contents were
# already copied out by torch.cat above, so reuse is safe here.
# NOTE(review): all_gather requires the input tensor to have the same
# shape on every rank — if batch_label's length differs per rank, the
# collective can hang or return garbage. Verify per-rank sizes match.
dist.all_gather(temp_list, batch_label)
batch_label = torch.cat(temp_list, dim=0)
print(batch_pred.shape)
print(batch_label.shape)
print(batch_label)

And strange thing happened:

>>> torch.Size([3408])
>>> torch.Size([3408])
# The program is still running

batch_label can’t be printed out, and the program got stuck.

P.S. print() is disabled when not on the master process.

Thanks for any feedback!

Can someone give me some clues? It seems that all_gather() triggered some problems.

Thanks for the question, @ojipadeson .

In order to further help you, could you show 1) how you are running the script, 2) the code initializing the process group, 3) the code instantiating batch_pred and batch_label?

I tried running the following on a CPU host and it worked:

# command:
# torchrun --nproc_per_node=2 all_gather.py

import torch
import torch.distributed as dist

def main():
    """Minimal all_gather repro: every rank contributes an identical tensor.

    Initializes a gloo process group, gathers a length-10 int64 ones tensor
    from each rank twice (as "pred" and "label"), concatenates the results,
    and prints the concatenated shapes and the gathered labels.
    """
    dist.init_process_group('gloo')

    world = dist.get_world_size()

    batch_pred = torch.ones(10, dtype=torch.int64)
    batch_label = torch.ones(10, dtype=torch.int64)

    # One receive buffer per rank; all_gather fills them in rank order.
    gathered = [torch.zeros_like(batch_pred, dtype=torch.int64)
                for _ in range(world)]

    dist.all_gather(gathered, batch_pred)
    batch_pred = torch.cat(gathered, dim=0)

    # The buffers are reused: torch.cat above already copied them out.
    dist.all_gather(gathered, batch_label)
    batch_label = torch.cat(gathered, dim=0)

    print(batch_pred.shape)
    print(batch_label.shape)
    print(batch_label)


if __name__ == '__main__':
    main()

@aazzolini Thanks for your reply.

  1. Running the script:
python -m torch.distributed.launch --nproc_per_node=4 distribute_train.py --my_args
  1. Initializing code:
def init_distributed_mode(args):
    """Populate args.rank / args.world_size / args.gpu from the environment
    and initialize the NCCL process group.

    Mutates `args` in place; sets args.distributed to indicate whether a
    distributed run was detected. Returns None.
    """
    if args.dist_on_itp:  # is False
        # OpenMPI launcher: translate OMPI_* env vars into the standard ones.
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: # is True
        # torch.distributed.launch / torchrun path: env vars already set.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # NOTE(review): this branch sets rank/gpu but never args.world_size
        # or args.dist_url, which init_process_group below reads — verify
        # they are set elsewhere before this path is used.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    # Bind this process to its local GPU before creating the process group.
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'  # backend is hard-coded; requires CUDA
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)

    # Synchronize all ranks, then silence print() on non-master processes
    # (setup_for_distributed is defined elsewhere in the project).
    dist.barrier()
    setup_for_distributed(args.rank == 0)
  1. batch_pred and batch_label in 1 process:
>>> torch.Size([852])
>>> torch.Size([852])

The main part of the evaluation code:

    # Accumulators for gathered predictions/labels across all batches.
    all_labels = []
    all_predictions = []
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(dev_loader):
            # Move every tensor field to the device; skip non-tensor fields.
            for key in batch:
                if key == "texts" or key == "filename":
                    continue
                batch[key] = batch[key].to(device, non_blocking=True)

            # Keep only positions covered by the attention mask.
            attention_mask = batch["attention_mask"]
            active_label = attention_mask.view(-1) == 1
            batch_label = batch["labels"]

            # After masking, batch_label is 1-D with a DATA-DEPENDENT length.
            batch_label = batch_label.view(-1)[active_label]
            print(batch_label.shape)  # [852]

            _, _, batch_pred = model("eval", batch)
            print(batch_pred.shape)  # [852]

            # NOTE(review): torch.distributed.all_gather requires the tensor
            # to have the SAME shape on every rank. Because the masked length
            # above varies with the data, ranks can disagree on size here,
            # which can hang the collective or corrupt the gathered values —
            # this matches the reported symptom. Confirm per-rank shapes, or
            # pad to a fixed length before gathering.
            temp_list = [torch.zeros_like(batch_pred, dtype=torch.int64)
                         for _ in range(utils.get_world_size())]
            dist.all_gather(temp_list, batch_pred)
            batch_pred = torch.cat(temp_list, dim=0)
            # temp_list is reused; torch.cat above already copied it out.
            dist.all_gather(temp_list, batch_label)
            batch_label = torch.cat(temp_list, dim=0)

            print(batch_label.shape)  # [3408]
            print(batch_pred.shape)  # [3408]
            # print(batch_label)
            print(batch_pred)

            # Drop padding positions (label 0) before accumulating metrics.
            ground_truth = batch_label[batch_label != 0].cpu().numpy().tolist()
            predictions = batch_pred[batch_label != 0].cpu().numpy().tolist()

            all_predictions.extend(predictions)
            all_labels.extend(ground_truth)

model is similar to a BertForTokensClassification model in pytorch_pretrained_bert package.