Torch DDP: Memory leak when loading datasets to memory

I have a huge NumPy array of size 2.5 million x 5200 that I can load into memory for training. I am trying to train a model on 4 GPUs with torch DDP, but I realized that loading it naively would put 4 copies of the array in memory, one per process. I tried splitting the array into 4 equally sized smaller arrays and having each process load only the one that matches its rank, but memory consumption rose very quickly, far beyond what loading the original array once would use, and I had to terminate the training script before my machine ran out of memory.
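For scale, a quick back-of-the-envelope calculation helps. The post does not say which dtype the array uses, so this sketch assumes either float32 or float64 and counts only the raw array data. In float32 the full array is about 52 GB, four full copies about 208 GB, and one quarter shard about 13 GB; all of these double in float64.

```python
# Rough memory footprint of the 2.5 million x 5200 array from the question.
# The dtype is an assumption (not stated in the post), so both common choices are shown.
ROWS = 2_500_000
COLS = 5_200
WORLD_SIZE = 4  # one process per GPU


def footprint_gb(rows: int, cols: int, itemsize: int) -> float:
    """Size of a dense rows x cols array in GB (10**9 bytes), ignoring object overhead."""
    return rows * cols * itemsize / 1e9


for name, itemsize in [("float32", 4), ("float64", 8)]:
    full = footprint_gb(ROWS, COLS, itemsize)
    shard = footprint_gb(ROWS // WORLD_SIZE, COLS, itemsize)
    print(f"{name}: full array {full:.0f} GB, "
          f"{WORLD_SIZE} full copies {WORLD_SIZE * full:.0f} GB, "
          f"one shard {shard:.0f} GB")
```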

I suspect that the data is somehow being duplicated across the processes internally. Does anybody know of a workaround for this?

Thank you!
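One way to confirm or rule out duplication is to log each process's peak resident memory right after its data is loaded. A minimal sketch using the standard-library resource module (Unix only); the np.ones allocation is a hypothetical stand-in for loading one shard. Note that RSS counts shared pages (for example, memory-mapped files) in every process that touches them, so adding up the per-process numbers can overstate the real total.

```python
import resource
import sys

import numpy as np


def peak_rss_mib() -> float:
    """Peak resident set size of the current process in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kibibytes on Linux but in bytes on macOS.
    if sys.platform == "darwin":
        return peak / (1024 * 1024)
    return peak / 1024


before = peak_rss_mib()
# Hypothetical stand-in for loading one shard; np.ones writes every element,
# so the pages really become resident (about 76 MiB for 2000 x 5000 float64).
data = np.ones((2_000, 5_000), dtype=np.float64)
after = peak_rss_mib()
print(f"peak RSS before: {before:.0f} MiB, after: {after:.0f} MiB")
```

In a DDP script, printing this together with the process rank shows whether each process holds roughly one shard or much more.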

Hi, if you have indeed assigned each process to only load a subset of the data, then memory consumption should definitely be less than having each process load the entire dataset into CPU memory.

Are you using PyTorch APIs such as Dataset / DataPipes / DistributedSampler to do this? A code snippet that shows the higher-than-expected memory usage would be valuable.
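It may help to note that DistributedSampler only splits the indices of a dataset across ranks; every rank that builds the full dataset still holds all of it in memory, so the sampler alone does not reduce memory use. The sketch below runs in a single process by passing num_replicas and rank explicitly (in a real DDP script these default to the process group's world size and rank). The 12-sample dataset is a toy stand-in.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy stand-in for the real data: 12 samples with 3 features each.
features = torch.arange(36, dtype=torch.float32).reshape(12, 3)
labels = torch.arange(12)
dataset = TensorDataset(features, labels)

WORLD_SIZE = 4
shards = {}
for rank in range(WORLD_SIZE):
    # Explicit num_replicas and rank mean no process group is required,
    # so every "rank" can be simulated in this one process.
    sampler = DistributedSampler(dataset, num_replicas=WORLD_SIZE, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    seen = []
    for _, y in loader:
        seen.extend(y.tolist())
    shards[rank] = seen
    print(rank, seen)
```

With 12 samples and 4 ranks, rank r sees indices r, r + 4 and r + 8, so the ranks never overlap. With shuffle=True, a real training loop would also call sampler.set_epoch(epoch) at the start of each epoch.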

Hi @rvarm1

Thanks for your reply. I am using TensorDataset, but I did not use DistributedSampler since each process should already see a different subset of the data. Each file is named after the rank of the process that should load it, i.e. chunks_1.npy is loaded by the process with rank 1 (utils.get_rank() returns the rank of the process):
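The per-rank file scheme described here can be sketched with np.array_split: split the full array row-wise once, save one file per rank, and let each process load only its own file. This is an illustration on a small toy array; in the real script the rank would come from the post's utils.get_rank() helper (or torch.distributed.get_rank()). The final assertion is the kind of check that verifies the shards partition the original array.

```python
import os
import tempfile

import numpy as np

WORLD_SIZE = 4


def write_shards(array: np.ndarray, out_dir: str, world_size: int) -> None:
    """Split `array` row-wise into `world_size` pieces saved as chunks_{rank}.npy."""
    for rank, piece in enumerate(np.array_split(array, world_size, axis=0)):
        np.save(os.path.join(out_dir, "chunks_{}.npy".format(rank)), piece)


def load_shard(out_dir: str, rank: int) -> np.ndarray:
    """Load only the shard that belongs to `rank`."""
    return np.load(os.path.join(out_dir, "chunks_{}.npy".format(rank)))


# Small stand-in for the real 2.5 million x 5200 array.
full = np.arange(12 * 3, dtype=np.float32).reshape(12, 3)
out_dir = tempfile.mkdtemp()
write_shards(full, out_dir, WORLD_SIZE)

shards = [load_shard(out_dir, rank) for rank in range(WORLD_SIZE)]
print([s.shape[0] for s in shards])
# Concatenating the shards in rank order should reproduce the original array.
assert np.array_equal(np.concatenate(shards), full)
```

Here every shard has 3 rows. Keeping the shards equal in size, as in the question, means every rank runs the same number of steps; otherwise ranks can fall out of step during DDP's gradient synchronization.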

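import os
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
# root_dir, batch_size_per_gpu, args and utils are defined elsewhere in the script
train_signals = torch.Tensor(np.load(os.path.join(root_dir, "chunks_{}.npy".format(utils.get_rank()))))
train_seqs = torch.LongTensor(np.load(os.path.join(root_dir, ...)))  # file name cut off in the original post
train_lengths = torch.LongTensor(np.load(os.path.join(root_dir, ...)))  # file name cut off in the original post
train_ds = TensorDataset(train_signals, train_seqs, train_lengths)
train_dl = DataLoader(train_ds, batch_size=batch_size_per_gpu,
                      num_workers=args.num_workers, pin_memory=True)  # remaining arguments cut off in the original post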

I have also checked that each "chunks_{}.npy".format(utils.get_rank()) file is the correct partition of the original, larger array.
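One possible contributor worth checking: torch.Tensor(...) always produces the default float32 dtype. If a chunk file was saved as float64 (NumPy's default for floating-point data), torch.Tensor has to build a new float32 tensor while the float64 array from np.load is still alive, so each process briefly holds both. torch.from_numpy instead wraps the existing buffer without copying and keeps its dtype, so converting the shard files to the training dtype once, offline, avoids the extra copy. The sketch below assumes a small, hypothetical float64 shard to show the difference. For read-only shards, np.load(path, mmap_mode="r") is another option to try, since it reads pages from disk lazily instead of all at once.

```python
import os
import tempfile

import numpy as np
import torch

# Hypothetical small shard standing in for chunks_{rank}.npy, stored as float64.
shard = np.random.default_rng(0).random((1000, 52))
path = os.path.join(tempfile.mkdtemp(), "chunks_0.npy")
np.save(path, shard)

arr = np.load(path)  # float64 array

# torch.Tensor converts to float32, which needs new storage: a second copy of the data.
copied = torch.Tensor(arr)

# torch.from_numpy shares the NumPy buffer and keeps float64: no copy is made.
shared = torch.from_numpy(arr)

print("torch.Tensor shares memory:", copied.data_ptr() == arr.ctypes.data)
print("torch.from_numpy shares memory:", shared.data_ptr() == arr.ctypes.data)
```

The same reasoning applies to torch.LongTensor(...) on the sequence and length arrays if those files are stored with a dtype other than int64.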