all_gather() gets stuck when data samples are not the same length

My all_gather call gets stuck when my data samples are not the same length (shown in the following code). The samples are padded to the same length per batch using collate_fn. What may cause it to get stuck, and how can I solve it?

It can be reproduced on a multi-GPU machine with the following code and running script.

Running configuration:

python -m torch.distributed.launch --nproc_per_node=4


import argparse
import os
import random

import torch
import torch.distributed as dist
import torch.nn as nn
from import DataLoader, Dataset, DistributedSampler

def get_args():
    """Build the command-line parser and return the parsed arguments."""
    p = argparse.ArgumentParser('script', add_help=False)
    p.add_argument('--world_size', type=int, default=1,
                   help='number of distributed processes')
    p.add_argument('--local_rank', type=int, default=-1)
    p.add_argument('--dist_on_itp', action='store_true')
    p.add_argument('--dist_url', default='env://',
                   help='url used to set up distributed training')
    return p.parse_args()

def setup_for_distributed(is_master):
    """Replace the builtin ``print`` so that only the master process prints.

    Non-master ranks stay silent unless a call explicitly passes
    ``force=True`` (useful for per-rank debugging output).
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # `force=True` lets any rank print regardless of master status.
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

class Model(nn.Module):
    """Toy token-classification model: embedding -> 2-layer LSTM -> linear head."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10001, 128, padding_idx=0)
        self.lstm = nn.LSTM(128, 128, 2)
        self.linear = nn.Linear(128, 20)

    def forward(self, batch):
        """Return ``(loss, predictions)`` for a padded batch.

        ``batch['data']`` holds token ids, ``batch['labels']`` per-token
        labels; predictions and loss are computed over the flattened
        (batch * seq_len, 20) logits.
        """
        tokens = batch['data']
        targets = batch['labels'].view(-1)

        hidden = self.embedding(tokens)
        hidden, _ = self.lstm(hidden)
        logits = self.linear(hidden).view(-1, 20)

        predictions = torch.max(logits, dim=1)[1]
        loss = nn.CrossEntropyLoss()(logits, targets)
        return loss, predictions

class TrainData(Dataset):
    def __init__(self):
        super(TrainData, self).__init__() = []
        self.label = []
        for _ in range(100):
            # data_len = 256  # when data_len unchanged, not stuck
            data_len = random.randint(200, 300)
  [random.randint(1, 10000) for _ in range(data_len)])
            self.label.append([random.randint(0, 19) for _ in range(data_len)])
    def __getitem__(self, item):
        return {'data':[item], 'labels': self.label[item]}
    def __len__(self):
        return len(

def collate_wrapper(batch):
    """Right-pad every sample in ``batch`` with zeros to the batch max length.

    Returns ``{'data': LongTensor(B, max_len), 'labels': LongTensor(B, max_len)}``.
    Note: padding is only to this batch's local maximum, not a global one.
    """
    longest = max(len(sample["data"]) for sample in batch)

    def pad(seq):
        return seq + [0] * (longest - len(seq))

    input_ids = torch.as_tensor([pad(s["data"]) for s in batch], dtype=torch.long)
    labels = torch.as_tensor([pad(s["labels"]) for s in batch], dtype=torch.long)
    return {"data": input_ids, "labels": labels}

args = get_args()
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.distributed = True

args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
    args.rank, args.dist_url, args.gpu), flush=True)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                        world_size=args.world_size, rank=args.rank)
setup_for_distributed(args.rank == 0)

device = torch.device("cuda")

dataset = TrainData()

num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
sampler_rank = global_rank
sampler = DistributedSampler(dataset, num_replicas=num_tasks, rank=sampler_rank)
data_loader =
    dataset, sampler=sampler,
model = Model()

all_labels = []
all_predictions = []
with torch.no_grad():
    for step, batch in enumerate(data_loader):
        for key in batch:
            batch[key] = batch[key].to(device, non_blocking=True)
        batch_label = batch['labels'].view(-1)
        _, batch_pred = model(batch)

        temp_list = [torch.zeros_like(batch_pred, dtype=torch.int64)
                     for _ in range(dist.get_world_size())]
        dist.all_gather(temp_list, batch_pred)
        batch_pred =, dim=0)
        dist.all_gather(temp_list, batch_label)
        batch_label =, dim=0)

        print(batch_label)  # get stuck here, batch_label cannot be printed

        ground_truth = batch_label[batch_label != 0].cpu().numpy().tolist()
        predictions = batch_pred[batch_label != 0].cpu().numpy().tolist()


Thanks for any feedback!

Or: could anybody give me a link to a repo whose model implements token classification with an attention mask and can run distributed all_gather properly?

If the all_gather call is hanging, it is probably due to mismatched tensor shapes across ranks. You can set TORCH_DISTRIBUTED_DEBUG=DETAIL to report the exact shapes and ranks that are mismatched. See: Distributed communication package - torch.distributed — PyTorch master documentation