In my case, each process samples its dataset from the full data independently, so every instantiated Dataset contains different examples. As a result, the number of samples per process differs, which can cause training to hang at certain epochs:
local_rank: 0, num_samples: 5024
local_rank: 1, num_samples: 5026
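To make the failure mode concrete: each rank wraps its own dataset in a DistributedSampler, so the per-rank batch count is derived from different dataset lengths and the ranks can disagree on how many iterations an epoch has. The rank with fewer batches leaves the loop while the others block in a collective call. A minimal sketch of the arithmetic, using the sample counts from the logs above and a hypothetical batch size of 16:

import math

counts = {0: 5024, 1: 5026}     # per-rank sample counts from the logs
world_size, batch_size = 2, 16  # batch size is a hypothetical value

for rank, n in counts.items():
    # DistributedSampler pads each rank's dataset to ceil(n / world_size)
    # samples per replica, but n itself already differs across ranks.
    per_replica = math.ceil(n / world_size)
    print(f"rank {rank}: {per_replica} per-replica samples, "
          f"{per_replica // batch_size} batches with drop_last, "
          f"{math.ceil(per_replica / batch_size)} without")

Without drop_last the ranks disagree (157 vs. 158 batches), so rank 1 starts an extra iteration and waits forever on its first collective op while rank 0 has already left the loop.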
Here is the code that builds the dataloader:
import torch
import torch.distributed as dist
from torch.utils import data


def build_data_loader(cfg, args, phase):
    sampler = None
    if cfg.DDP.ENABLE:
        num_tasks = dist.get_world_size()
        global_rank = dist.get_rank()
        # Rank 0 builds the dataset first; the other ranks wait at a
        # barrier inside this context manager before building theirs.
        with torch_distributed_zero_first(args.local_rank):
            dataset = build_dataset(cfg, phase)
        print(f"local rank {args.local_rank} successfully built {phase} dataset")
    else:
        dataset = build_dataset(cfg, phase)

    if cfg.DDP.ENABLE:
        # The same sampler is used for both the train and eval phases.
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank
        )

    data_loader = data.DataLoader(
        dataset=dataset,
        batch_size=cfg.DATA_LOADER.BATCH_SIZE,
        # Shuffle only when training without a DistributedSampler;
        # otherwise the sampler handles shuffling.
        shuffle=(phase == 'train' and sampler is None),
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        sampler=sampler,
        drop_last=(phase == 'train'),
    )
    return data_loader
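For reference, a minimal sketch of how the loader is driven each epoch; the cfg.TRAIN.NUM_EPOCHS key is an assumption for whatever the config actually names the epoch count, and DistributedSampler's set_epoch() must be called so every rank draws the same shuffled order:

train_loader = build_data_loader(cfg, args, phase='train')

for epoch in range(cfg.TRAIN.NUM_EPOCHS):  # epoch-count key is assumed
    if isinstance(train_loader.sampler, torch.utils.data.DistributedSampler):
        # Required so all ranks reshuffle with the same permutation.
        train_loader.sampler.set_epoch(epoch)
    for batch in train_loader:
        ...  # training step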