Distributed training: how to avoid loading dataset in memory N times

I would like to use DistributedSampler for MPI training of a neural net.

The problem is that, as far as I can tell, even if each process only loops over a subset of the data, the full dataset is loaded in memory once per process.

What’s the correct way to train a neural net (e.g. with SGD) on multiple processes without loading the dataset in memory N times? (I could split the data between workers manually, but that gets tricky fast because I would need to make sure that each worker processes the same exact number of batches at every epoch).
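
For reference, the manual split mentioned above could look roughly like the sketch below. The helper name `shard_bounds` and its parameters are made up for illustration and are not part of PyTorch. It assumes each process takes one contiguous slice of the data and that trailing samples which do not fill a whole batch on every process are simply dropped.

```python
def shard_bounds(num_samples, world_size, rank, batch_size):
    """Return (start, stop) of the contiguous slice owned by `rank`.

    Hypothetical helper: every rank gets the same number of full batches,
    and samples that do not fit are dropped from the end.
    """
    # Largest number of full batches that every rank can receive.
    batches_per_rank = num_samples // (world_size * batch_size)
    per_rank = batches_per_rank * batch_size
    start = rank * per_rank
    return start, start + per_rank


if __name__ == "__main__":
    # 100 samples, 3 processes, batches of 8: each rank gets 4 batches.
    for rank in range(3):
        print(rank, shard_bounds(100, 3, rank, 8))
```

Each process would then load only `data[start:stop]`. The price is that the assignment of samples to processes is fixed unless the shards are recomputed and reloaded between epochs.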


A simple reproducer of the large memory usage: run with mpirun -n N python test.py and observe that each process always occupies the same amount of RAM, whatever N is.

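import time

import torch as to
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

if __name__ == "__main__":
    # DistributedSampler reads rank and world size from the process group,
    # so it must be initialized first ("mpi" needs an MPI-enabled PyTorch build).
    dist.init_process_group("mpi")
    # Every process builds the full 100M-element tensor (about 800 MB as int64).
    ds = TensorDataset(to.arange(100000000))
    sampler = DistributedSampler(ds)
    dl = DataLoader(ds, batch_size=2, sampler=sampler)
    for b in dl:
        # Slow the loop down so per-process memory can be inspected.
        time.sleep(0.01)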

Partial answer, I think: it looks like DistributedSampler duplicates some datapoints so that every process runs over the same number of batches. When shuffling is on and set_epoch is called, it duplicates different datapoints at each epoch and reshuffles the datapoints between processes. It would be awkward to do all this if each process didn't have a copy of the data.
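
To see the duplication and reshuffling concretely, here is a small sketch that runs in a single process. It passes num_replicas and rank explicitly so no process group is needed; the toy sizes (10 samples, 3 replicas) and the helper name `indices_per_rank` are chosen only for illustration.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

ds = TensorDataset(torch.arange(10))
world_size = 3  # pretend there are 3 processes


def indices_per_rank(epoch):
    """Collect the dataset indices each simulated rank would visit."""
    per_rank = []
    for rank in range(world_size):
        sampler = DistributedSampler(
            ds, num_replicas=world_size, rank=rank, shuffle=True, seed=0
        )
        # The shuffle depends on seed + epoch, so this changes the split.
        sampler.set_epoch(epoch)
        per_rank.append(list(sampler))
    return per_rank


if __name__ == "__main__":
    for epoch in range(2):
        print("epoch", epoch, indices_per_rank(epoch))
```

With 10 samples and 3 replicas, each rank receives 4 indices, so 12 indices are handed out in total and 2 of them are repeats. Which samples are repeated, and which rank sees which sample, depends on the epoch passed to set_epoch. Since any index can land on any rank, a map-style dataset has to be able to serve every index on every process.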