Parallel training loads all batches at once instead of batch by batch

I have a dataset that I’m feeding into a neural network, but the dataset is larger than the memory available on my machine (I’m training on a single CPU with 24 cores and 32 GB of memory). I’m trying to load the data in batches and train in parallel, but for some reason the memory for all of the batches is allocated at once and the system crashes: if my total dataset is 40 GB, then instead of loading (40 GB / nbatches) * nprocesses, it loads the full 40 GB at once.

I’ve tried various implementations of DDP, torch.multiprocessing, torch.distributed, and DataParallel, but I haven’t been able to figure it out.

The code is essentially:

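import os
import pickle

import torch

class Dataset(torch.utils.data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self):
        self.srcDir = os.getcwd() + '/batches'
        # attribute name was cut off in the post; `files` is a placeholder
        self.files = os.listdir(self.srcDir)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # each file holds one pre-built batch
        fname = os.path.join(self.srcDir, self.files[index])
        with open(fname, 'rb') as f1:
            X = pickle.load(f1)
        return X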
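# argument list not shown in the post; one argument to match the Process call
def parallelFit(self, partition):
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.optim import LBFGS
    from torch.utils.data.distributed import DistributedSampler

    train_dataset = Dataset()
    # the sampler line was cut off in the post; DistributedSampler is assumed
    train_sampler = DistributedSampler(train_dataset)
    batches = torch.utils.data.DataLoader(dataset=train_dataset, sampler=train_sampler)

    # for a CPU module, device_ids is left unset (None)
    ddp_model = DDP(self.model)
    optimizer = LBFGS(ddp_model.parameters())

    ### I see the same issue if I don't use DDP and if I don't use DistributedSampler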

Then everything is run with:

processes = []
for rank in range(num_processes):
    # args must be a tuple, hence the trailing comma
    p = Process(target=calc.parallelFit, args=(partitions.use(rank),))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
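For reference, here is a minimal sketch of how DistributedSampler splits the indices among processes, assuming 10 pre-batched files and 4 processes (both numbers are illustrative). Passing num_replicas and rank explicitly lets the partitioning be inspected without initializing a process group; in a real run each process would build only its own sampler. Because each rank only requests its own indices, a Dataset that opens one file per __getitem__ should, in principle, keep per-process memory proportional to the batches in flight rather than to the whole dataset.

```python
from torch.utils.data import DistributedSampler

# Stand-in for the list of pre-batched files: only its length matters here.
file_indices = list(range(10))

world_size = 4
shards = {}
for rank in range(world_size):
    # Explicit num_replicas/rank: no process group is needed to inspect this.
    sampler = DistributedSampler(file_indices, num_replicas=world_size,
                                 rank=rank, shuffle=False)
    shards[rank] = list(sampler)

print(shards)  # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6, 0], 3: [3, 7, 1]}
```

With drop_last=False (the default), the index list is padded by repeating its first entries so that every rank gets the same number of samples, which is why ranks 2 and 3 revisit files 0 and 1.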

I’ve been reading through the forums and the documentation, but I feel like I’m just missing something.

Based on your __getitem__, it seems like you have pre-batched the data and are now loading the files individually. The pickle.load() call is a little suspicious because it might use a lot of memory. Have you tried reading from a single source and setting the batch_size parameter on torch.utils.data.DataLoader instead? It would also help to see how you are using the data loader in the training loop.

Thanks for the reply! The data is pre-batched because setting up our training data is fairly complex and requires some preprocessing; mainly, there was no easy way to save our input to one file without running out of memory. For smaller training sets I’ve tried saving everything to one file and then loading it with different batch sizes in DataLoader, and I get the same behavior either way. As for pickle.load(), it does use some extra memory, but the overhead always seems to be a constant factor.
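Since each file already holds a complete batch, one option is to turn off the DataLoader’s automatic batching with batch_size=None, so each item returned by __getitem__ is yielded as-is instead of being wrapped in an extra batch of size 1. The sketch below assumes each file stores a single tensor; PickledBatchDataset, the batch_*.pkl names, and the temporary directory are illustrative and not from the original code. Only the file list lives on the dataset, and each call reads exactly one file.

```python
import os
import pickle
import tempfile

import torch
from torch.utils.data import DataLoader, Dataset


class PickledBatchDataset(Dataset):
    """Hypothetical dataset: one pickle file per pre-built batch, read lazily."""

    def __init__(self, src_dir):
        self.src_dir = src_dir
        # Only file names are kept in memory; the data stays on disk.
        self.files = sorted(os.listdir(src_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Unpickle a single file per call, so only the batches currently
        # being consumed are resident in memory.
        with open(os.path.join(self.src_dir, self.files[index]), "rb") as f:
            return pickle.load(f)


# Write three small pre-batched files to a temporary directory.
tmp_dir = tempfile.mkdtemp()
for i in range(3):
    with open(os.path.join(tmp_dir, f"batch_{i:03d}.pkl"), "wb") as f:
        pickle.dump(torch.full((4, 2), float(i)), f)

# batch_size=None disables automatic batching: each dataset item is
# already a batch, so it is yielded unchanged.
loader = DataLoader(PickledBatchDataset(tmp_dir), batch_size=None, shuffle=False)
shapes = [tuple(b.shape) for b in loader]
print(shapes)  # [(4, 2), (4, 2), (4, 2)]
```

With the default batch_size=1 the same loader would yield tensors of shape (1, 4, 2), and custom batch objects like the one used in the training loop may additionally need a collate_fn.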

After loading the batches as shown above, the training loop is more or less:

Inside parallelFit():

batches = torch.utils.data.DataLoader(dataset=train_dataset, sampler=train_sampler)
optimizer = LBFGSScipy(ddp_model.parameters(), max_iter=maxEpochs, logger=logger, rank=rank)

def closure():
    loss = 0
    energyloss = 0
    forceloss = 0
    energyRMSE = 0
    forceRMSE = 0
    batchid = 0
    for batch in batches:
        epoch = 0
        #print('allFPs', batch.allElement_fps)
        predEnergies, predForces = self.model(batch.allElement_fps, batch.dgdx, batch)
        loss += criterion(predEnergies, predForces, batch.energies, batch.forces,
                          natomsEnergy=batch.natomsPerimageEnergy,
                          natomsForce=batch.natomsPerimageForce)
        lossgrads = torch.autograd.grad(loss, self.model.parameters(),
                                        retain_graph=True, create_graph=False)
        for p, g in zip(self.model.parameters(), lossgrads):
            #if batchid == 0:
            p.grad = g
            #    p.grad += g
        batchid += 1
        energyloss += criterion.energyloss
        forceloss += criterion.forceloss
        #if parallel:
        #    average_gradients(self.model)
        #    dist.all_reduce(loss, dist.ReduceOp.SUM)
        #    dist.all_reduce(energyloss, dist.ReduceOp.SUM)
        #    dist.all_reduce(forceloss, dist.ReduceOp.SUM)
        #if rank == 0:
        # the start of this call was cut off in the post; logger.info is assumed
        logger.info('%s', "{:12d} {:12.8f} {:12.8f} {:12.8f}".format(epoch, loss.item(), energyRMSE, forceRMSE))
        if epoch % self.logmodel_interval == 0:
            pass  # body not shown in the post

    #if parallel:
    #    average_gradients(self.model)
    #    dist.all_reduce(loss, dist.ReduceOp.SUM)
    #    dist.all_reduce(energyloss, dist.ReduceOp.SUM)
    #    dist.all_reduce(forceloss, dist.ReduceOp.SUM)
    energyRMSE = np.sqrt(energyloss.item() / self.nimages)
    forceRMSE = np.sqrt(forceloss.item() / self.nimages)
    if energyRMSE < self.energyRMSEtol and forceRMSE < self.forceRMSEtol:
        logger.info('Minimization converged')  # call name assumed; cut off in the post
        io.saveFF(self.model, self.preprocessParas, filename="mlff.pyamff")
    return loss, energyRMSE, forceRMSE
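One thing in this closure may well explain the memory growth: loss is accumulated across batches with loss += criterion(...) and differentiated with retain_graph=True after every batch, so the autograd graph of every batch (including the intermediate tensors it saved) stays referenced until the closure returns, and each torch.autograd.grad call backpropagates through all earlier batches again. A common alternative is to call backward() on each batch’s own loss, which adds into p.grad and frees that batch’s graph, and to keep only a detached float for reporting. The sketch below shows that pattern with torch.optim.LBFGS; the toy linear model, the random batches, and MSELoss(reduction="sum") are stand-ins (assumptions) for self.model, the real data, criterion, and the custom LBFGSScipy. It also checks that the accumulated per-batch gradients match the gradient of the full-dataset loss.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a small linear model and four in-memory "batches".
model = torch.nn.Linear(3, 1)
batches = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(4)]
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.LBFGS(model.parameters(), max_iter=20,
                              line_search_fn="strong_wolfe")


def closure():
    # LBFGS may call the closure many times; start each call from zero grads.
    optimizer.zero_grad()
    total_loss = 0.0
    for x, y in batches:
        loss = criterion(model(x), y)
        # backward() adds this batch's gradient into p.grad and frees the
        # batch's graph, so at most one batch's graph is alive at a time.
        loss.backward()
        # Keep only a plain number for reporting, not the graph.
        total_loss += loss.item()
    # Gradients of the summed loss are already in p.grad.
    return torch.tensor(total_loss)


# Per-batch accumulation gives the same gradient as one full-dataset pass.
closure()
accumulated = [p.grad.clone() for p in model.parameters()]
model.zero_grad()
x_all = torch.cat([x for x, _ in batches])
y_all = torch.cat([y for _, y in batches])
criterion(model(x_all), y_all).backward()
full = [p.grad.clone() for p in model.parameters()]

loss_before = closure().item()
optimizer.step(closure)
loss_after = closure().item()
print(loss_before, loss_after)
```

The summed (rather than mean) reduction is what makes the per-batch gradients add up exactly to the full-dataset gradient; with a per-batch mean, the batches would need reweighting by their sizes.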

I left in some extra code that’s commented out; those are just different approaches I’ve tried. I kept them in case I was on the right path with any of them.
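On the commented-out synchronization: as far as I can tell, DDP’s automatic gradient averaging will not take effect here, because the forward pass goes through self.model rather than ddp_model, and the DDP documentation notes that it does not work with torch.autograd.grad (gradients must be accumulated into .grad). If so, something like the average_gradients helper in the commented code is needed so that every rank hands LBFGS the same loss and gradient. Below is a minimal sketch of such helpers built on dist.all_reduce; the names average_gradients and average_scalar are illustrative, and the demo runs a single gloo process (world_size=1, file-based init in a temporary directory) only so that it is self-contained.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def average_gradients(model):
    """Average p.grad across all ranks of the default process group."""
    world_size = float(dist.get_world_size())
    for p in model.parameters():
        if p.grad is None:
            continue
        # Sum this gradient over all ranks in place, then divide.
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        p.grad /= world_size


def average_scalar(value):
    """Average a Python float (e.g. a reported loss) across ranks."""
    t = torch.tensor([value], dtype=torch.float64)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t.item() / dist.get_world_size()


# Single-process demo: with world_size=1 averaging is the identity, but
# these are the same calls each rank would make in a multi-process run.
init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
dist.init_process_group("gloo", init_method=f"file://{init_file}",
                        rank=0, world_size=1)

model = torch.nn.Linear(3, 1)
model(torch.randn(5, 3)).sum().backward()
before = [p.grad.clone() for p in model.parameters()]
average_gradients(model)
after = [p.grad.clone() for p in model.parameters()]
reported = average_scalar(2.5)

dist.destroy_process_group()
```

In a multi-process run these calls would go at the end of the closure, after all local batches have been processed, so that each rank contributes its full local gradient exactly once per evaluation.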

Let me know if I should show any additional parts of the code.