Dataset with randomness for DDP

In my case, the dataset is randomly sampled from the full data, so each instantiated Dataset contains different samples. As a result, the number of samples on each process is different, which can cause training to hang at certain epochs.

local_rank: 0, num_samples: 5024
local_rank: 1, num_samples: 5026
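
To confirm the mismatch, a check along these lines can be run right after build_dataset (a rough sketch, assuming an NCCL backend with one GPU per process and an already-initialized process group):

import torch
import torch.distributed as dist

# compare the dataset length across ranks right after build_dataset
local_len = torch.tensor([len(dataset)], dtype=torch.long, device=f"cuda:{args.local_rank}")
all_lens = [torch.zeros_like(local_len) for _ in range(dist.get_world_size())]
dist.all_gather(all_lens, local_len)
if dist.get_rank() == 0:
    print("per-rank num_samples:", [int(t.item()) for t in all_lens])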

And here is my code for building the dataloader:

import torch
import torch.distributed as dist
from torch.utils import data


def build_data_loader(cfg, args, phase):
    sampler = None

    if cfg.DDP.ENABLE:
        print(f"local rank {args.local_rank} successfully built {phase} dataset")
        num_tasks = dist.get_world_size()
        global_rank = dist.get_rank()
        # let rank 0 build the dataset first, then the other ranks follow
        with torch_distributed_zero_first(args.local_rank):
            dataset = build_dataset(cfg, phase)
    else:
        dataset = build_dataset(cfg, phase)

    if cfg.DDP.ENABLE:
        # the original train/val branches were identical, so one call suffices
        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank)

    data_loader = data.DataLoader(
        dataset=dataset,
        batch_size=cfg.DATA_LOADER.BATCH_SIZE,
        # shuffle in the DataLoader only when no DistributedSampler is used
        shuffle=(phase == 'train' and sampler is None),
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        sampler=sampler,
        drop_last=(phase == 'train'),
    )

    return data_loader
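
For reference, since a DistributedSampler is used for training, the training loop should also call set_epoch every epoch so that shuffling changes across epochs and stays consistent across ranks. A minimal sketch (max_epochs and train_one_epoch are placeholders, not my actual code):

from torch.utils.data import DistributedSampler

train_loader = build_data_loader(cfg, args, phase='train')
for epoch in range(max_epochs):
    if isinstance(train_loader.sampler, DistributedSampler):
        train_loader.sampler.set_epoch(epoch)  # reshuffle consistently across ranks
    train_one_epoch(train_loader)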

Hey @sqiangcao. To avoid hanging, you may try adding torch.distributed.barrier at the end of each iteration or epoch to avoid an imbalanced load between processes. Also, I am curious why there can be a different number of samples per rank, since DistributedSampler uses padding to balance the load. You might wanna double-check this.
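
Something along these lines (just a sketch; max_epochs and train_one_epoch are placeholders for your own loop):

import torch.distributed as dist

for epoch in range(max_epochs):
    train_one_epoch(train_loader)  # your per-epoch training
    dist.barrier()                 # all ranks synchronize before the next epoch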

@justanhduc, thank you for your answer. I found that the dataset lengths on rank 0 and rank 1 are different, which may be the reason for the hang.
So I tried:

        # gather a copy of the dataset object from every rank, then keep rank 0's copy everywhere
        ds_m = [None for _ in range(args.world_size)]
        dist.all_gather_object(ds_m, dataset)
        dataset = ds_m[0]

and the problem is solved. Is there a better way?
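
A related option I have been wondering about (just a sketch, not verified) is to broadcast only rank 0's dataset instead of gathering a copy from every rank:

        # only rank 0 provides the dataset; the other ranks receive it
        obj_list = [dataset if dist.get_rank() == 0 else None]
        dist.broadcast_object_list(obj_list, src=0)
        dataset = obj_list[0]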