Question about the behavior of torch.utils.data.distributed.DistributedSampler

Hi,

I am currently playing around with distributed training, and I am having some trouble understanding it:

When using torch.utils.data.distributed.DistributedSampler, does it split the whole dataset D into D1 and D2 (considering the one-node, two-GPU scenario), with each process only iterating over data from one of D1 or D2, so that a batch is never generated from a mix of D1 and D2 data? Or is the split performed again automatically each time one epoch finishes and another one starts?
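To make the question concrete, here is a rough standalone sketch of the split I mean (the TensorDataset is just a stand-in dataset, and num_replicas/rank are passed explicitly only so it runs without spawning processes):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))  # 16 dummy samples

# One sampler per "process"; each rank appears to get a disjoint half of the indices.
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=True)
    print(f"rank {rank}:", list(sampler))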

I have a similar question here. If this is the way the DistributedSampler samples data, and the portion of the data on each single GPU is quite small, the network on each GPU may overfit to its small portion of the dataset.

Will the portion of the dataset on each single GPU be updated after each epoch?
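To illustrate the behaviour I am worried about, here is a rough standalone sketch (again with num_replicas and rank passed explicitly just so it runs without launching processes): if nothing extra is done between epochs, does rank 0 keep getting exactly the same subset?

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

# Iterating twice without telling the sampler that the epoch changed:
# both printed lists come out identical, i.e. the same fixed subset.
print("epoch 0:", list(sampler))
print("epoch 1:", list(sampler))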

Yes, I also have a question about this. Does anyone have any insight into it?

Hi, I have found that your question can be solved by calling sampler.set_epoch(e). Just try my code: each GPU gets shuffled, different data in each epoch.

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler


torch.distributed.init_process_group(backend="nccl")

input_size = 5
batch_size = 2
data_size = 16

# On a single node the global rank equals the local rank, so it can be
# used directly to pick the GPU for this process.
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)


class RandomDataset(Dataset):
    """16 rows of width 5; row i is filled with the value i + 1, which makes
    it easy to see which samples each rank receives."""

    def __init__(self, size, length):
        self.len = length
        self.data = (torch.arange(1, length + 1, dtype=torch.float32)
                     .unsqueeze(1).repeat(1, size).to(device))

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


dataset = RandomDataset(input_size, data_size)
sampler = DistributedSampler(dataset)
rand_loader = DataLoader(dataset=dataset,
                         batch_size=batch_size,
                         sampler=sampler)

for e in range(2):
    # Re-seed the shuffle for this epoch; without this call every epoch
    # would reuse the same per-rank split.
    sampler.set_epoch(e)
    for data in rand_loader:
        print(f"rank {local_rank}, epoch {e}:", data)
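In case it helps anyone trying this out: the script needs one process per GPU, e.g. with torchrun (or the older python -m torch.distributed.launch). Assuming two GPUs on one node and that the file is saved as demo.py (adjust both to your setup):

torchrun --nproc_per_node=2 demo.py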