All_gather() get stuck when data samples are not the same length

My all_gather function gets stuck when my data samples are not the same length(show in following code). The samples will be padded to the same length using collate_fn. What may cause the stuck and how to solve this?

It can be repeat on a multi-GPUs machine by the following code and running script.

Running configuration:

python -m torch.distributed.launch --nproc_per_node=4 debug_gather.py

Code:

import os
import random

import torch
import torch.nn as nn
from torch.utils.data import DistributedSampler, Dataset
import torch.distributed as dist

import argparse

def get_args():
    parser = argparse.ArgumentParser('script', add_help=False)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser.parse_args()

def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(10001, 128, padding_idx=0)
        self.lstm = nn.LSTM(128, 128, 2)
        self.linear = nn.Linear(128, 20)

    def forward(self, batch):
        x = batch['data']
        label = batch['labels'].view(-1)

        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(x)

        x = x.view(-1, 20)
        loss_fct = nn.CrossEntropyLoss()
        pred = torch.max(x, dim=1)[1]
        loss = loss_fct(x, label)
        return loss, pred

class TrainData(Dataset):
    def __init__(self):
        super(TrainData, self).__init__()
        self.data = []
        self.label = []
        for _ in range(100):
            # data_len = 256  # when data_len unchanged, not stuck
            data_len = random.randint(200, 300)
            self.data.append([random.randint(1, 10000) for _ in range(data_len)])
            self.label.append([random.randint(0, 19) for _ in range(data_len)])
    def __getitem__(self, item):
        return {'data': self.data[item], 'labels': self.label[item]}
    def __len__(self):
        return len(self.data)

def collate_wrapper(batch):
    max_seq_len = max([len(s["data"]) for s in batch])
    input_ids = torch.as_tensor([s["data"] + [0] * (max_seq_len - len(s["data"])) for s in batch], dtype=torch.long)
    labels = torch.as_tensor([s["labels"] + [0] * (max_seq_len - len(s["labels"]))
                              for s in batch], dtype=torch.long)
    samples = {"data": input_ids, "labels": labels}
    return samples

args = get_args()
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.distributed = True

torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
    args.rank, args.dist_url, args.gpu), flush=True)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                        world_size=args.world_size, rank=args.rank)
dist.barrier()
setup_for_distributed(args.rank == 0)

device = torch.device("cuda")

dataset = TrainData()

num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
sampler_rank = global_rank
sampler = DistributedSampler(dataset, num_replicas=num_tasks, rank=sampler_rank)
data_loader = torch.utils.data.DataLoader(
    dataset, sampler=sampler,
    batch_size=8,
    drop_last=True,
    collate_fn=collate_wrapper
)
model = Model()
model.to(device)

all_labels = []
all_predictions = []
model.eval()
with torch.no_grad():
    for step, batch in enumerate(data_loader):
        for key in batch:
            batch[key] = batch[key].to(device, non_blocking=True)
        batch_label = batch['labels'].view(-1)
        _, batch_pred = model(batch)

        temp_list = [torch.zeros_like(batch_pred, dtype=torch.int64)
                     for _ in range(dist.get_world_size())]
        dist.all_gather(temp_list, batch_pred)
        batch_pred = torch.cat(temp_list, dim=0)
        dist.all_gather(temp_list, batch_label)
        batch_label = torch.cat(temp_list, dim=0)

        print(batch_label.shape)
        print(batch_pred.shape)
        print(batch_label)  # get stuck here, batch_label cannot be printed

        ground_truth = batch_label[batch_label != 0].cpu().numpy().tolist()
        predictions = batch_pred[batch_label != 0].cpu().numpy().tolist()

        all_predictions.extend(predictions)
        all_labels.extend(ground_truth)
print("Finished")

Thanks for any feedback!

Or: Could anybody give me a repo link that its model implements token classification with attention mask and can run distributed all_gather properly?

If the all_gather call is hanging it is probably due to mismatched shapes. You can use TORCH_DISTRIBUTED_DEBUG=DETAIL to tell you the exact shapes and ranks that are mismatched. Distributed communication package - torch.distributed — PyTorch master documentation