all_gather() gets stuck when attention_mask contains zeros

all_gather() gets stuck when there are zeros in attention_mask (shown in the code below). What might cause this problem, and how can it be solved?

It can be reproduced on a multi-GPU machine with the following launch command and code.

Launch command:

python -m torch.distributed.launch --nproc_per_node=4 debug_attention.py

Code:

import os
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler, Dataset
import torch.distributed as dist

import argparse


def get_args():
    parser = argparse.ArgumentParser('script', add_help=False)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser.parse_args()


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(10001, 128, padding_idx=0)
        self.lstm = nn.LSTM(128, 128, 2, batch_first=True)  # input is [batch, seq_len, emb]
        self.linear = nn.Linear(128, 20)

    def forward(self, batch):
        x = batch['data']
        label = batch['labels'].view(-1)
        attention_mask = batch['attention_mask']

        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(x)
        active_loss = attention_mask.view(-1) == 1
        x = x.view(-1, 20)[active_loss]
        loss_fct = nn.CrossEntropyLoss()
        pred = torch.max(x, dim=1)[1]
        loss = loss_fct(x, label[active_loss])
        return loss, pred


class TrainData(Dataset):
    def __init__(self):
        super(TrainData, self).__init__()
        self.data = []
        for _ in range(100):
            data_len = random.randint(400, 500)
            input_ids = [random.randint(1, 10000) for _ in range(data_len)] + [0] * (512 - data_len)
            labels = [random.randint(0, 19) for _ in range(data_len)] + [0] * (512 - data_len)
            attention_mask = [1 if idx else 0 for idx in input_ids]
            # attention_mask = [1] * len(input_ids)  # not stuck if no 0 in attention mask
            self.data.append({"ids": input_ids, "labels": labels, "attention_mask": attention_mask})

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)


def collate_wrapper(batch):
    max_seq_len = max([len(s["ids"]) for s in batch])
    input_ids = torch.as_tensor([s["ids"] + [0] * (max_seq_len - len(s["ids"])) for s in batch], dtype=torch.long)
    labels = torch.as_tensor([s["labels"] + [0] * (max_seq_len - len(s["labels"]))
                              for s in batch], dtype=torch.long)
    attention_mask = torch.as_tensor([s["attention_mask"] + [0] * (max_seq_len - len(s["attention_mask"]))
                                      for s in batch], dtype=torch.long)
    samples = {"data": input_ids, "labels": labels, "attention_mask": attention_mask}
    return samples


args = get_args()
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.distributed = True

torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
    args.rank, args.dist_url, args.gpu), flush=True)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                        world_size=args.world_size, rank=args.rank)
dist.barrier()
setup_for_distributed(args.rank == 0)

device = torch.device("cuda")

dataset = TrainData()

num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
sampler_rank = global_rank
sampler = DistributedSampler(dataset, num_replicas=num_tasks, rank=sampler_rank)
data_loader = DataLoader(
    dataset, sampler=sampler,
    batch_size=8,
    drop_last=True,
    collate_fn=collate_wrapper
)
model = Model()
model.to(device)

all_labels = []
all_predictions = []
model.eval()
with torch.no_grad():
    for step, batch in enumerate(data_loader):
        for key in batch:
            batch[key] = batch[key].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"]  # [bts,seq_len]
        active_label = attention_mask.view(-1) == 1
        batch_label = batch['labels'].view(-1)[active_label]
        _, batch_pred = model(batch)

        temp_list = [torch.zeros_like(batch_pred, dtype=torch.int64)
                     for _ in range(dist.get_world_size())]
        dist.all_gather(temp_list, batch_pred)
        batch_pred = torch.cat(temp_list, dim=0)
        dist.all_gather(temp_list, batch_label)
        batch_label = torch.cat(temp_list, dim=0)

        print(batch_label.shape)
        print(batch_pred.shape)
        print(batch_label)  # gets stuck here; batch_label is never printed

        ground_truth = batch_label[batch_label != 0].cpu().numpy().tolist()
        predictions = batch_pred[batch_label != 0].cpu().numpy().tolist()

        all_predictions.extend(predictions)
        all_labels.extend(ground_truth)
print("Finished")

Thanks for any feedback!

Update: I gathered all_predictions and all_labels with all_gather_object() instead of gathering batch_label and batch_pred with all_gather(), and the hang went away. However, I still cannot find the cause of the hang.
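
For reference, here is a minimal sketch of that workaround (illustrative only: it assumes it replaces the two dist.all_gather() calls inside the evaluation loop, and the batch_label != 0 filtering is omitted for brevity). Unlike dist.all_gather(), which expects every rank to contribute a tensor of the same shape, all_gather_object() pickles arbitrary Python objects, so the per-rank lists may have different lengths.

# Sketch of the all_gather_object() workaround; replaces the two
# dist.all_gather() calls inside the evaluation loop.
batch_pred_list = batch_pred.cpu().tolist()
batch_label_list = batch_label.cpu().tolist()

# One slot per rank; all_gather_object() fills each slot with that
# rank's (possibly differently sized) list.
gathered_preds = [None for _ in range(dist.get_world_size())]
gathered_labels = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(gathered_preds, batch_pred_list)
dist.all_gather_object(gathered_labels, batch_label_list)

for preds, labels in zip(gathered_preds, gathered_labels):
    all_predictions.extend(preds)
    all_labels.extend(labels)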