Hello, I am wondering how exactly these lines of code work:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=train_sampler)
If world_size = 8 (8 GPUs), does this DataLoader take 8 batches at a time and give each GPU one distinct batch? For example, with 16 batches [0,1,2,3,4,5,6,7,8,9…], would GPU 1 take batch 0, GPU 2 take batch 1, GPU 3 take batch 2, and so on (assuming no shuffling)? If not, can someone explain exactly how it works?
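For reference, `DistributedSampler` splits the dataset's sample indices, not whole batches. Each process builds its own sampler with its own `rank`, the sampler gives that process every `num_replicas`-th index starting at `rank`, and that process's `DataLoader` then groups only those indices into batches of `batch_size`. With 8 GPUs, one iteration across all processes therefore covers 8 × batch_size samples. The sketch below is a pure-Python mirror of the non-shuffled index logic, assuming `shuffle=False, drop_last=False` on the sampler; `distributed_indices` and `rank_batches` are hypothetical helper names, not PyTorch APIs. Note that `DistributedSampler` itself defaults to `shuffle=True` (the `shuffle=False` on the `DataLoader` only disables the loader's own shuffling, which cannot be combined with a custom `sampler` anyway), so with the code above the indices are permuted with a seed before the split, and `train_sampler.set_epoch(epoch)` should be called each epoch to change that order.

```python
import math


def distributed_indices(dataset_len, num_replicas, rank):
    """Hypothetical helper mirroring DistributedSampler(shuffle=False, drop_last=False)."""
    indices = list(range(dataset_len))
    # Every rank must receive the same number of samples, so round up...
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    # ...and pad by wrapping around to the start of the index list.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Strided split: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


def rank_batches(dataset_len, num_replicas, rank, batch_size):
    """Group one rank's indices into batches, as that rank's DataLoader would."""
    idx = distributed_indices(dataset_len, num_replicas, rank)
    return [idx[i:i + batch_size] for i in range(0, len(idx), batch_size)]


if __name__ == "__main__":
    # 32 samples, 8 GPUs, batch_size=2: each rank sees 4 samples in 2 batches.
    for rank in range(8):
        print(f"rank {rank}: {rank_batches(32, 8, rank, batch_size=2)}")
```

Running it prints rank 0 getting [[0, 8], [16, 24]], rank 1 getting [[1, 9], [17, 25]], and so on, so GPU 0's first batch holds samples 0 and 8 rather than a contiguous "batch 0". When the dataset size is not divisible by num_replicas, the wrap-around padding means a few samples are seen twice per epoch.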