Distributed training: how to avoid loading dataset in memory N times

I would like to use DistributedSampler for MPI training of a neural net.

The problem is that, as far as I can tell, even if each process only loops over a subset of the data, the full dataset is loaded in memory once per process.

What’s the correct way to train a neural net (e.g. with SGD) on multiple processes without loading the dataset in memory N times? (I could split the data between workers manually, but that gets tricky fast because I would need to make sure that each worker processes the same exact number of batches at every epoch).
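For comparison, a manual split can stay balanced if every rank gets the same number of samples and the remainder is dropped. Below is a minimal sketch. It assumes the samples can be addressed by a contiguous index range, such as rows of an array on disk. shard_range and batches_per_epoch are hypothetical helper names.

```python
def shard_range(n_samples: int, world_size: int, rank: int) -> range:
    """Hypothetical helper: the contiguous block of sample indices owned by `rank`.

    Every rank gets exactly n_samples // world_size samples; the last
    n_samples % world_size samples are dropped so no rank has extra batches.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    per_rank = n_samples // world_size
    start = rank * per_rank
    return range(start, start + per_rank)


def batches_per_epoch(n_local: int, batch_size: int) -> int:
    # Drop the final partial batch so every rank runs the same number of steps
    return n_local // batch_size


if __name__ == "__main__":
    for rank in range(3):
        r = shard_range(10, world_size=3, rank=rank)
        # Each rank would load only rows r.start:r.stop of the data
        print(rank, list(r), batches_per_epoch(len(r), batch_size=2))
```

The trade-off is that each rank sees the same fixed shard every epoch. DistributedSampler, by contrast, reshuffles samples across ranks. Shuffling within a shard still works.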


A simple reproducer of the high memory usage: run it with mpirun -n N test.py and observe that every process occupies the same amount of RAM.

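from torch.utils.data import DistributedSampler, DataLoader, TensorDataset
import torch as to
import torch.distributed as dist
import time

if __name__ == "__main__":
  # The MPI backend takes rank and world size from mpirun
  # (requires a PyTorch build with MPI support)
  dist.init_process_group(backend="mpi")
  ds = TensorDataset(to.arange(100000000))  # every process builds the full dataset
  sampler = DistributedSampler(ds)          # but iterates only over its own share of indices
  dl = DataLoader(ds, batch_size=2, sampler=sampler)
  for b in dl:
    time.sleep(0.01)  # placeholder loop body: slows the loop so per-process RAM can be inspected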

Partial answer, I think: it looks like DistributedSampler duplicates datapoints so that all processes can run on the same number of batches. At different epochs, DistributedSampler duplicates different datapoints, and reshuffles datapoints between processes. It would be awkward to do all this if each process didn’t have a copy of the data.
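This padding and reshuffling can be checked without a process group. Passing num_replicas and rank explicitly lets one process inspect every rank's indices. The 10-sample dataset and 3 replicas below are arbitrary illustration values.

```python
from torch.utils.data import DistributedSampler

dataset = list(range(10))  # the sampler only needs len(dataset)

# No shuffling: indices are padded to 12 by repeating 0 and 1, then strided by rank
for rank in range(3):
    sampler = DistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False)
    print(rank, list(sampler))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]

# With shuffling, set_epoch changes which samples each rank sees every epoch
for epoch in range(2):
    shards = []
    for rank in range(3):
        sampler = DistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=True, seed=0)
        sampler.set_epoch(epoch)
        shards.append(list(sampler))
    print(epoch, shards)
```

Because any rank may need any sample in a later epoch, each process must be able to reach the whole dataset. That does not require a private in-memory copy, though; it only requires access.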

Hi @bluehood

Do you have any additional insight into this? I’m running into this problem as well.

@ptrblck and @fduwjj – is there any way to get around this? I have 8 GPUs on a single node, and each GPU has 12.3GB of memory. The Linux machine I’m using has a CPU with 128GB memory. My dataset is upwards of 30GB, so making 8 copies of this is intractable, unless I lazy load the data, which would add a non-trivial amount of pre-processing time for every batch.
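One way to keep lazy loading cheap is to run the expensive preprocessing once. The result is saved to disk, and every rank then memory-maps that file instead of reprocessing per batch. Below is a minimal sketch for a single node. The helper name load_preprocessed, the .npy path, and build_fn are all hypothetical stand-ins for whatever the dataset class currently does.

```python
import os

import numpy as np
import torch.distributed as dist


def load_preprocessed(path, build_fn, local_rank):
    """Hypothetical helper: preprocess once, then memory-map in every rank.

    `path` is a .npy file all local ranks can read; `build_fn` returns the
    fully preprocessed NumPy array. Both are assumptions for illustration.
    """
    if local_rank == 0 and not os.path.exists(path):
        tmp_path = path + ".tmp.npy"
        np.save(tmp_path, build_fn())
        os.replace(tmp_path, path)  # publish the finished file atomically
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # other ranks wait here until the file exists
    # Read-only memory map: physical pages are shared through the OS page cache
    return np.load(path, mmap_mode="r")
```

Each rank can then wrap the returned array in a Dataset whose __getitem__ copies out single rows. The per-batch cost becomes a read from the page cache rather than a repeat of the preprocessing.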

Any help is much appreciated.


Lazily loading the dataset is the common approach, as it wouldn't require many changes to your code and you could keep using DistributedSampler. If that's not possible, try using torch.multiprocessing to share the data as described here.
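As an illustration of lazy loading, here is a minimal sketch of a Dataset backed by a memory-mapped .npy file. It assumes the data has already been saved as one sample per row; the class name MemmapDataset is hypothetical. Pages are read on demand and shared through the OS page cache, so N processes mapping the same file do not each hold a private copy.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class MemmapDataset(Dataset):
    """Hypothetical example: lazily reads rows of a .npy file via a read-only memory map."""

    def __init__(self, path: str):
        self.path = path
        # Only the header is parsed here; no sample data is loaded
        self._len = np.load(path, mmap_mode="r").shape[0]
        self._data = None  # opened on first access, separately in each process

    def _array(self) -> np.ndarray:
        if self._data is None:
            self._data = np.load(self.path, mmap_mode="r")
        return self._data

    def __getstate__(self):
        # Never pickle the mapping (that would copy the data); reopen it after unpickling
        state = self.__dict__.copy()
        state["_data"] = None
        return state

    def __len__(self) -> int:
        return self._len

    def __getitem__(self, idx: int) -> torch.Tensor:
        # np.array(...) copies just this one row out of the mapping
        return torch.from_numpy(np.array(self._array()[idx]))
```

It plugs into the usual setup unchanged, e.g. DataLoader(MemmapDataset("data.npy"), batch_size=..., sampler=DistributedSampler(...)). The file name there is only a placeholder. The __getstate__ override matters when DataLoader workers are started with the spawn method.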

Thanks @ptrblck

I'm currently using torch.distributed rather than torch.multiprocessing (see below for template code). Would using torch.multiprocessing instead of torch.distributed avoid making copies of the full dataset, and is there a reason to prefer one over the other?

import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def train(local_world_size, local_rank, args):

    # GPU setup: split the visible GPUs evenly across the local processes
    n = torch.cuda.device_count() // local_world_size
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))

    # Dataset1 is the user's own Dataset class (not shown)
    dataset = Dataset1(root=args.root_dir,
                       in_memory=False,               # load data up front (don't lazy load)
                       n_samples=args.n_samples)      # only load subset of data; further arguments (...) elided

    # set up data samplers (on a single node, local_rank equals the global rank)
    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=True, seed=0)

    # set up dataloaders
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=1, sampler=sampler)

def main(local_world_size, local_rank, args):
    # These are the parameters used to initialize the process group
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")  # backend assumed; this call was missing from the snippet
    print(
        f"[{os.getpid()}] world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()}"
    )

    # train your model
    train(local_world_size, local_rank, args)

    # Tear down the process group
    dist.destroy_process_group()

The approach above produces one copy of the loaded dataset per device. Within the Dataset1 class, the loaded tensors are already stored on the CPU via the data.cpu() method, both when loading up front and when lazy loading; all data is then transferred to the GPUs during batching.
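Regarding the torch.multiprocessing question: sharing works when a single parent process loads the data once and launches the training processes itself. The parent moves the tensor into shared memory and passes it to each child. The process group from torch.distributed is still used for communication, so the two are complementary rather than alternatives. With torchrun, by contrast, each rank is an independent process and a tensor cannot be handed over this way. Below is a minimal CPU sketch. It makes several assumptions: the gloo backend, port 29500, and a toy tensor in place of the real dataset. For GPU training, nccl and moving batches to the device would be the usual choices.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def worker(rank, world_size, data):
    # Assumed rendezvous settings for a single machine
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    dataset = TensorDataset(data)         # wraps the shared tensor, no copy
    sampler = DistributedSampler(dataset)  # world size and rank come from the process group
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    n_batches = sum(1 for _ in loader)
    print(f"rank {rank}: {n_batches} batches, data in shared memory: {data.is_shared()}")

    dist.destroy_process_group()


def main(world_size=2):
    data = torch.arange(1000, dtype=torch.float32)  # stand-in for the real dataset, loaded once
    data.share_memory_()  # children receive a handle to this memory, not a copy
    mp.spawn(worker, args=(world_size, data), nprocs=world_size)


if __name__ == "__main__":
    main()
```

With two processes, each rank should report 125 batches, since 500 samples per rank divided into batches of 4 gives 125. Each rank should also report that the data lives in shared memory.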