I would like to use DistributedSampler for MPI training of a neural net.
The problem is that, as far as I can tell, even though each process only iterates over its subset of the data, the full dataset is loaded into memory once per process.
What's the correct way to train a neural net (e.g. with SGD) on multiple processes without loading the dataset into memory N times? (I could split the data between workers manually, but that gets tricky fast, because I would need to make sure each worker processes exactly the same number of batches at every epoch.)
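On the "same number of batches" worry: as far as I understand, `DistributedSampler` already handles this by padding the index list with repeated indices until it divides evenly across ranks, so every rank sees the same sample count. A minimal standalone sketch of that splitting logic (my own reimplementation for illustration, not the actual library code):

```python
import math

def rank_indices(num_samples, num_replicas, rank, pad=True):
    """Split dataset indices across ranks, padding so every rank
    gets the same count (mirrors DistributedSampler's default)."""
    indices = list(range(num_samples))
    if pad:
        # round up to a multiple of num_replicas and repeat the
        # first few indices to fill the gap
        per_rank = math.ceil(num_samples / num_replicas)
        total = per_rank * num_replicas
        indices += indices[: total - len(indices)]
    # strided assignment: rank r takes positions r, r+R, r+2R, ...
    return indices[rank::num_replicas]

# 10 samples over 3 ranks: not divisible, but every rank still gets 4
counts = {r: len(rank_indices(10, 3, r)) for r in range(3)}
print(counts)  # {0: 4, 1: 4, 2: 4}
```

So uneven dataset sizes shouldn't force a manual split; the remaining question is only about the memory duplication.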
A simple reproducer of the large memory usage: run with
`mpirun -n N python test.py` and observe that each process occupies the same amount of RAM, regardless of N.
```python
from torch.utils.data import DistributedSampler, DataLoader, TensorDataset
import torch as to
import torch.distributed as dist
import time

if __name__ == "__main__":
    dist.init_process_group("mpi")
    # the full tensor is materialized here, once per process
    ds = TensorDataset(to.arange(100000000))
    # each rank iterates only over its own shard of the indices...
    sampler = DistributedSampler(ds)
    dl = DataLoader(ds, batch_size=2, sampler=sampler)
    # ...but the underlying tensor is still fully resident in every process
    for b in dl:
        time.sleep(1)
```