My all_gather call gets stuck when my data samples are not all the same length (shown in the code below). The samples are padded to the same length inside collate_fn. What could be causing the hang, and how can I fix it?
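For context, collate_wrapper (in the full script below) pads only to the longest sample within each batch, so at the same step different ranks can end up with different padded lengths. The snippet below is a rough, untested distillation of that situation; the lengths 274 and 291 are made-up values, just to show that each rank hands all_gather a tensor of a different size:

import os
import torch
import torch.distributed as dist

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
dist.init_process_group(backend='nccl', init_method='env://')

# made-up padded lengths: each rank's batch happens to pad to a different maximum
seq_len = 274 if dist.get_rank() == 0 else 291
pred = torch.zeros(8 * seq_len, dtype=torch.int64, device='cuda')

gathered = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, pred)  # the tensor sizes differ across ranks at this point
print(dist.get_rank(), pred.shape)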
It can be reproduced on a multi-GPU machine with the following code and launch command.
Running configuration:
python -m torch.distributed.launch --nproc_per_node=4 debug_gather.py
Code:
import os
import random
import torch
import torch.nn as nn
from torch.utils.data import DistributedSampler, Dataset
import torch.distributed as dist
import argparse
def get_args():
    parser = argparse.ArgumentParser('script', add_help=False)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser.parse_args()
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(10001, 128, padding_idx=0)
        self.lstm = nn.LSTM(128, 128, 2)
        self.linear = nn.Linear(128, 20)

    def forward(self, batch):
        x = batch['data']
        label = batch['labels'].view(-1)
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(x)
        x = x.view(-1, 20)
        loss_fct = nn.CrossEntropyLoss()
        pred = torch.max(x, dim=1)[1]
        loss = loss_fct(x, label)
        return loss, pred
class TrainData(Dataset):
    def __init__(self):
        super(TrainData, self).__init__()
        self.data = []
        self.label = []
        for _ in range(100):
            # data_len = 256  # with a fixed data_len it does not get stuck
            data_len = random.randint(200, 300)
            self.data.append([random.randint(1, 10000) for _ in range(data_len)])
            self.label.append([random.randint(0, 19) for _ in range(data_len)])

    def __getitem__(self, item):
        return {'data': self.data[item], 'labels': self.label[item]}

    def __len__(self):
        return len(self.data)
def collate_wrapper(batch):
    max_seq_len = max([len(s["data"]) for s in batch])
    input_ids = torch.as_tensor([s["data"] + [0] * (max_seq_len - len(s["data"])) for s in batch], dtype=torch.long)
    labels = torch.as_tensor([s["labels"] + [0] * (max_seq_len - len(s["labels"]))
                              for s in batch], dtype=torch.long)
    samples = {"data": input_ids, "labels": labels}
    return samples
args = get_args()
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
    args.rank, args.dist_url, args.gpu), flush=True)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                        world_size=args.world_size, rank=args.rank)
dist.barrier()
setup_for_distributed(args.rank == 0)
device = torch.device("cuda")
dataset = TrainData()
num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
sampler_rank = global_rank
sampler = DistributedSampler(dataset, num_replicas=num_tasks, rank=sampler_rank)
data_loader = torch.utils.data.DataLoader(
    dataset, sampler=sampler,
    batch_size=8,
    drop_last=True,
    collate_fn=collate_wrapper
)
model = Model()
model.to(device)
all_labels = []
all_predictions = []
model.eval()
with torch.no_grad():
    for step, batch in enumerate(data_loader):
        for key in batch:
            batch[key] = batch[key].to(device, non_blocking=True)
        batch_label = batch['labels'].view(-1)
        _, batch_pred = model(batch)
        temp_list = [torch.zeros_like(batch_pred, dtype=torch.int64)
                     for _ in range(dist.get_world_size())]
        dist.all_gather(temp_list, batch_pred)
        batch_pred = torch.cat(temp_list, dim=0)
        dist.all_gather(temp_list, batch_label)
        batch_label = torch.cat(temp_list, dim=0)
        print(batch_label.shape)
        print(batch_pred.shape)
        print(batch_label)  # gets stuck here; batch_label is never printed
        ground_truth = batch_label[batch_label != 0].cpu().numpy().tolist()
        predictions = batch_pred[batch_label != 0].cpu().numpy().tolist()
        all_predictions.extend(predictions)
        all_labels.extend(ground_truth)
print("Finished")
Thanks for any feedback!