all_gather()
get stuck when there’s zero in attention_mask
(show in the following code). What may cause the problem and how to solve this problem?
It can be repeated on a multi-GPUs machine with the following running script and code.
Running configuration:
python -m torch.distributed.launch --nproc_per_node=4 debug_attention.py
Code:
import os
import random
import torch
import torch.nn as nn
from torch.utils.data import DistributedSampler, Dataset
import torch.distributed as dist
import argparse
def get_args():
parser = argparse.ArgumentParser('script', add_help=False)
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
return parser.parse_args()
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.embedding = nn.Embedding(10001, 128, padding_idx=0)
self.lstm = nn.LSTM(128, 128, 2)
self.linear = nn.Linear(128, 20)
def forward(self, batch):
x = batch['data']
label = batch['labels'].view(-1)
attention_mask = batch['attention_mask']
x = self.embedding(x)
x, _ = self.lstm(x)
x = self.linear(x)
active_loss = attention_mask.view(-1) == 1
x = x.view(-1, 20)[active_loss]
loss_fct = nn.CrossEntropyLoss()
pred = torch.max(x, dim=1)[1]
loss = loss_fct(x, label[active_loss])
return loss, pred
class TrainData(Dataset):
def __init__(self):
super(TrainData, self).__init__()
self.data = []
for _ in range(100):
data_len = random.randint(400, 500)
input_ids = [random.randint(1, 10000) for _ in range(data_len)] + [0] * (512 - data_len)
labels = [random.randint(0, 19) for _ in range(data_len)] + [0] * (512 - data_len)
attention_mask = [1 if idx else 0 for idx in input_ids]
# attention_mask = [1] * len(input_ids) # not stuck if no 0 in attention mask
self.data.append({"ids": input_ids, "labels": labels, "attention_mask": attention_mask})
def __getitem__(self, item):
return self.data[item]
def __len__(self):
return len(self.data)
def collate_wrapper(batch):
max_seq_len = max([len(s["ids"]) for s in batch])
input_ids = torch.as_tensor([s["ids"] + [0] * (max_seq_len - len(s["ids"])) for s in batch], dtype=torch.long)
labels = torch.as_tensor([s["labels"] + [0] * (max_seq_len - len(s["labels"]))
for s in batch], dtype=torch.long)
attention_mask = torch.as_tensor([s["attention_mask"] + [0] * (max_seq_len - len(s["attention_mask"]))
for s in batch], dtype=torch.long)
samples = {"data": input_ids, "labels": labels, "attention_mask": attention_mask}
return samples
args = get_args()
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
dist.barrier()
setup_for_distributed(args.rank == 0)
device = torch.device("cuda")
dataset = TrainData()
num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
sampler_rank = global_rank
sampler = DistributedSampler(dataset, num_replicas=num_tasks, rank=sampler_rank)
data_loader = torch.utils.data.DataLoader(
dataset, sampler=sampler,
batch_size=8,
drop_last=True,
collate_fn=collate_wrapper
)
model = Model()
model.to(device)
all_labels = []
all_predictions = []
model.eval()
with torch.no_grad():
for step, batch in enumerate(data_loader):
for key in batch:
batch[key] = batch[key].to(device, non_blocking=True)
attention_mask = batch["attention_mask"] # [bts,seq_len]
active_label = attention_mask.view(-1) == 1
batch_label = batch['labels'].view(-1)[active_label]
_, batch_pred = model(batch)
temp_list = [torch.zeros_like(batch_pred, dtype=torch.int64)
for _ in range(dist.get_world_size())]
dist.all_gather(temp_list, batch_pred)
batch_pred = torch.cat(temp_list, dim=0)
dist.all_gather(temp_list, batch_label)
batch_label = torch.cat(temp_list, dim=0)
print(batch_label.shape)
print(batch_pred.shape)
print(batch_label) # get stuck here, batch label cannot be printed
ground_truth = batch_label[batch_label != 0].cpu().numpy().tolist()
predictions = batch_pred[batch_label != 0].cpu().numpy().tolist()
all_predictions.extend(predictions)
all_labels.extend(ground_truth)
print("Finished")
Thanks for any feedback!