DistributedSampler and Subset() data duplication with DDP

I have a single file containing N samples that I want to split into train and val subsets while training with DDP. I'm not sure I'm doing this correctly, because I see the same training samples on multiple processes. To reproduce the issue, I created a trivial dataset class. Here is my current approach:

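```python
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler

class TrivialDataset(Dataset):
    """Each sample is just its own index, so we can see which indices each rank receives."""

    def __init__(self, N):
        self.length = N

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return index
```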

This class is for illustration only: `__getitem__` returns the index itself, so I can print the indices each DDP process retrieves.

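```python
dataset = TrivialDataset(20)
train_subset, val_subset = torch.utils.data.random_split(dataset, [16, 4])
# random_split already returns Subset objects; pass their .indices when re-wrapping.
# (Passing the Subset itself as indices only works here because TrivialDataset returns the index.)
train_dataset = Subset(dataset, train_subset.indices)
val_dataset = Subset(dataset, val_subset.indices)
```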
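Each DDP process runs this code on its own, and `DistributedSampler` only divides up positions `0..len(subset)-1` of whatever subset it is handed; it never sees the other ranks' subsets. So one thing worth ruling out is that the ranks end up with different splits, which could happen if their random number generators are in different states when `random_split` is called (for example, if the script seeds each rank differently). The pure-Python sketch below simulates that situation for two ranks. `shuffled_split` and `rank_shard` are hypothetical stand-ins for `random_split` and the sampler, not PyTorch code, and seeding with `10 + rank` is an assumed per-rank seeding scheme.

```python
import random

def shuffled_split(n, lengths, seed):
    """Hypothetical stand-in for random_split: shuffle range(n) with a seeded RNG,
    then cut the permutation into consecutive pieces of the given lengths."""
    perm = list(range(n))
    random.Random(seed).shuffle(perm)
    pieces, start = [], 0
    for length in lengths:
        pieces.append(perm[start:start + length])
        start += length
    return pieces

def rank_shard(train_indices, world_size, rank):
    """Indices one rank sees when the sampler strides over the subset
    (shuffle=False and len(train_indices) divisible by world_size, e.g. 16 on 2 ranks)."""
    return train_indices[rank::world_size]

world_size = 2

# Case 1: every rank splits with the same seed, so all ranks hold the same subsets.
splits = [shuffled_split(20, [16, 4], seed=10) for _ in range(world_size)]
shards = [rank_shard(splits[r][0], world_size, r) for r in range(world_size)]
print("same seed, shared indices:", set(shards[0]) & set(shards[1]))  # empty set

# Case 2 (hypothetical): each rank's RNG is in a different state (seed 10 + rank),
# so each rank builds its *own* train subset and the strided shards can overlap.
splits_diff = [shuffled_split(20, [16, 4], seed=10 + r) for r in range(world_size)]
shards_diff = [rank_shard(splits_diff[r][0], world_size, r) for r in range(world_size)]
print("per-rank seeds, shared indices:", set(shards_diff[0]) & set(shards_diff[1]))
```

If this turns out to be the cause, `random_split` accepts a `generator` argument, so passing an identically seeded generator on every rank, e.g. `torch.utils.data.random_split(dataset, [16, 4], generator=torch.Generator().manual_seed(seed))`, should make the split the same everywhere.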

I have confirmed that `train_dataset` and `val_dataset` contain disjoint indices. Next, I create a separate `DistributedSampler` for each subset:

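```python
# world_size and global_rank come from the DDP setup,
# e.g. torch.distributed.get_world_size() and torch.distributed.get_rank().
seed = 10  # only used by DistributedSampler when shuffle=True
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=global_rank,
                                   shuffle=False, seed=seed, drop_last=True)
val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=global_rank,
                                 shuffle=False, seed=seed, drop_last=True)
```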
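The stride pattern the sampler uses can be sketched in pure Python. `distributed_positions` is a hypothetical, simplified mirror of `DistributedSampler` with `shuffle=False` (the `drop_last` arithmetic follows the sampler's documented behaviour of dropping the tail, or repeating samples when `drop_last=False`, so every rank gets the same count), and `train_indices` is an invented split of `TrivialDataset(20)`. It shows that the sampler yields positions within the `Subset`, which the `Subset` then maps to original indices, so the shards are disjoint as long as every rank holds the same `train_indices`.

```python
import math

def distributed_positions(dataset_len, num_replicas, rank, drop_last=True):
    """Hypothetical mirror of DistributedSampler(shuffle=False): positions for one rank."""
    positions = list(range(dataset_len))
    if drop_last and dataset_len % num_replicas != 0:
        # Round down so every rank gets the same number of samples.
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    if drop_last:
        positions = positions[:total_size]  # discard the tail
    else:
        padding = total_size - len(positions)
        if padding > 0:
            # Repeat positions from the start until every rank has num_samples.
            positions += (positions * math.ceil(padding / len(positions)))[:padding]
    # Rank r takes every num_replicas-th position, starting at r.
    return positions[rank:total_size:num_replicas]

# Hypothetical train split of TrivialDataset(20): 16 original indices.
train_indices = [7, 2, 19, 11, 0, 5, 14, 8, 16, 3, 12, 9, 1, 18, 6, 10]
world_size = 2
shards = [
    [train_indices[p] for p in distributed_positions(len(train_indices), world_size, r)]
    for r in range(world_size)
]
print(shards[0])  # [7, 19, 0, 14, 16, 12, 1, 6]
print(shards[1])  # [2, 11, 5, 8, 3, 9, 18, 10]
```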

Finally, I create a data loader for each subset:

```python
# The sampler decides which indices this rank sees, so shuffle is not passed to DataLoader.
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)
```

When I iterate over `train_loader` on 2 PyTorch DDP processes and print the indices each process retrieves, some indices appear on both processes. I expected the train_subset indices, and separately the val_subset indices, to be split into disjoint shards across the 2 processes, but that is not what happens. The ultimate goal is to scale up to many more than 2 processes.

If I don't split the indices into train and val subsets, the code works as expected, so the problem seems tied to the split, but I'm not sure where I am going wrong. I've searched this forum but haven't come across a similar issue.
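To check for duplicates once there are many processes, it can help to collect every rank's indices in one place and compare each pair. A minimal sketch, assuming the process group is already initialized and that `world_size` and `global_rank` are defined as in the code above; `find_overlaps` is a hypothetical helper, while `torch.distributed.all_gather_object` is an existing PyTorch collective that gathers a picklable object from every rank into a pre-sized list.

```python
from itertools import combinations

def find_overlaps(indices_per_rank):
    """Given one list of retrieved indices per rank, return
    {(rank_a, rank_b): sorted shared indices} for every pair of ranks that overlap."""
    overlaps = {}
    for (a, idx_a), (b, idx_b) in combinations(enumerate(indices_per_rank), 2):
        shared = set(idx_a) & set(idx_b)
        if shared:
            overlaps[(a, b)] = sorted(shared)
    return overlaps

# In the DDP script, each rank could gather all ranks' indices like this
# (requires an initialized process group, so it is left as comments):
#   local = [int(i) for batch in train_loader for i in batch]
#   gathered = [None] * world_size
#   torch.distributed.all_gather_object(gathered, local)
#   if global_rank == 0:
#       print(find_overlaps(gathered))

print(find_overlaps([[0, 2, 4], [1, 3, 4], [5]]))  # {(0, 1): [4]}
```

An empty dict means no two ranks received the same index in that epoch.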

cc @VitalyFedyunin for DistributedSampler questions