I would like to use DistributedSampler for MPI training of a neural net.
The problem is that, as far as I can tell, even though each process only iterates over its subset of the data, the full dataset is loaded into memory once per process.
What's the correct way to train a neural net (e.g. with SGD) on multiple processes without loading the dataset into memory N times? (I could split the data between workers manually, but that gets tricky fast, because I would need to make sure each worker processes exactly the same number of batches at every epoch.)
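On the "same number of batches" worry: as far as I understand, `DistributedSampler` already handles this by padding the index list with repeated indices until it divides evenly across ranks, so every rank sees the same sample count. A minimal standalone sketch of that splitting logic (my own reimplementation for illustration, not the actual library code):

```python
import math

def rank_indices(num_samples, num_replicas, rank, pad=True):
    """Split dataset indices across ranks, padding so every rank
    gets the same count (mirrors DistributedSampler's default)."""
    indices = list(range(num_samples))
    if pad:
        # round up to a multiple of num_replicas and repeat the
        # first few indices to fill the gap
        per_rank = math.ceil(num_samples / num_replicas)
        total = per_rank * num_replicas
        indices += indices[: total - len(indices)]
    # strided assignment: rank r takes positions r, r+R, r+2R, ...
    return indices[rank::num_replicas]

# 10 samples over 3 ranks: not divisible, but every rank still gets 4
counts = {r: len(rank_indices(10, 3, r)) for r in range(3)}
print(counts)  # {0: 4, 1: 4, 2: 4}
```

So uneven dataset sizes shouldn't force a manual split; the remaining question is only about the memory duplication.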
A simple reproducer of the large memory usage: run with
`mpirun -n N python test.py` and observe that each process occupies the same amount of RAM, regardless of N.
```python
from torch.utils.data import DistributedSampler, DataLoader, TensorDataset
import torch as to
import torch.distributed as dist
import time

if __name__ == "__main__":
    dist.init_process_group("mpi")
    # the full tensor is materialized here, once per process
    ds = TensorDataset(to.arange(100000000))
    # each rank iterates only over its own shard of the indices...
    sampler = DistributedSampler(ds)
    dl = DataLoader(ds, batch_size=2, sampler=sampler)
    # ...but the underlying tensor is still fully resident in every process
    for b in dl:
        time.sleep(1)
```