How do I copy data to GPU in parallel?

I have an image dataset that doesn’t fit in memory. I want to read minibatches off disk, copy them to GPU and train a model on them.

PyTorch's DataLoader has been very helpful in hiding the cost of loading each minibatch by using multiple worker processes, but copying to the GPU is still sequential.

I'm trying to pipeline my training loop so that copying data to the GPU happens in parallel with the rest of the work (forward pass, backpropagation, etc.) (something like this).

I summarise what I have tried so far below, but I think I'm going far down the wrong path with it. What is the correct way to do this in PyTorch?


What I’ve tried:

So far my training loop looks like this:

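import torch
from torch.utils.data import DataLoader

# MemmapDataset, param, epochs, device and neural_net are defined elsewhere in my code
trainset = MemmapDataset("dataset.npy")
trainloader = DataLoader(trainset, batch_size=param.batch_size,
                         shuffle=True, num_workers=4, pin_memory=True)
for i in range(epochs):
    for features, labels in trainloader:
        # Blocking host-to-device copies: training does not continue until they finish
        features = features.to(device)
        labels = labels.to(device)

        predictions = neural_net(features)
        ... rest of training ...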

I want to change that to(device) operation so that it runs concurrently with the rest of training: the copy should run asynchronously on one of the CUDA streams while the other training kernels run. Right now, all of training stops and the GPU sits essentially idle while it copies data.
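A minimal sketch of the usual side-stream prefetching pattern, assuming the DataLoader yields (features, labels) tensor pairs and is built with pin_memory=True (a non_blocking copy from pinned memory can overlap with kernels running on a different stream). `CUDAPrefetcher` is a hypothetical helper, not a PyTorch API. The workers keep returning CPU tensors; the copy of the next batch is issued in the main process on a second CUDA stream while the current batch is being used.

```python
import torch


class CUDAPrefetcher:
    """Hypothetical helper: copy the *next* batch to the GPU on a side stream
    while the current batch is used on the current (default) stream.

    Assumes `loader` yields (features, labels) pairs of CPU tensors, ideally
    pinned (pin_memory=True). For a non-CUDA device it just does plain copies.
    """

    def __init__(self, loader, device):
        self.loader = loader
        self.device = torch.device(device)
        self.use_cuda = self.device.type == "cuda"
        self.stream = torch.cuda.Stream(device=self.device) if self.use_cuda else None

    def _load(self, it):
        try:
            features, labels = next(it)
        except StopIteration:
            return None
        if not self.use_cuda:
            return features.to(self.device), labels.to(self.device)
        # Issue the copies on the side stream so they can overlap with compute
        with torch.cuda.stream(self.stream):
            return (features.to(self.device, non_blocking=True),
                    labels.to(self.device, non_blocking=True))

    def __iter__(self):
        it = iter(self.loader)
        next_batch = self._load(it)
        while next_batch is not None:
            if self.use_cuda:
                current = torch.cuda.current_stream(self.device)
                # Kernels queued after this point wait for the side-stream copy
                current.wait_stream(self.stream)
                for t in next_batch:
                    # Tell the caching allocator the tensor is now used on `current`
                    t.record_stream(current)
            batch = next_batch
            next_batch = self._load(it)  # start copying the following batch
            yield batch
```

Usage would look like `for features, labels in CUDAPrefetcher(trainloader, device): predictions = neural_net(features)`. Keeping all CUDA work in the main process is also why this avoids the CUDA initialization error that appears when tensors are moved to the GPU inside DataLoader worker processes.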

I tried making to(device) one of the transforms that runs in my dataset class's __getitem__ method, like so:

import torch

class ToDeviceTransform:
    def __init__(self, device):
        self.device = device

    def __call__(self, data: torch.Tensor):
        return data.contiguous().to(self.device)

This works if num_workers=0 and pin_memory=False in the DataLoader, but it is still just a sequential copy, and it's slower because I lose the parallel loading from disk.

Setting num_workers > 0 throws the following runtime error at my ToDeviceTransform.__call__() hack:
------------- | Cut long traceback above | -------------------

  File "/media/ihexx/Shared_Partition/projects/ProjectBoom/CarlaEnv/vae/datasets.py", line 42, in __call__
    return data.contiguous().to(self.device)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/site-packages/torch/cuda/__init__.py", line 163, in _lazy_init
    torch._C._cuda_init()
RuntimeError: cuda runtime error (3) : initialization error at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THC/THCGeneral.cpp:55


I then tried setting mp.set_start_method('spawn') in the main thread before creating the dataloader object, but that throws a MemoryError, which I assume is because the processes creating the cuda tensors are exiting and releasing memory? Haven’t been able to get past this.

Here’s the traceback:
  File "/media/ihexx/Shared_Partition/projects/ProjectBoom/CarlaEnv/vae/train_offline.py", line 53, in train
    dataiter = iter(trainloader)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 193, in __iter__
    return _DataLoaderIter(self)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 469, in __init__
    w.start()
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/multiprocessing/process.py", line 105, in start
    self._popen = self._Popen(self)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/multiprocessing/context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/home/ihexx/anaconda3/envs/boom/lib/python3.6/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
MemoryError
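One possible cause of a MemoryError inside `ForkingPickler(...).dump(obj)`, offered as an assumption since MemmapDataset's code isn't shown: with the 'spawn' start method each DataLoader worker receives a pickled copy of the dataset, and pickling a NumPy memmap serializes the array's data rather than just the file path. If the dataset keeps the memmap as an attribute, every worker start may try to materialize the whole file. A sketch of a dataset that stores only the path and opens the memmap lazily in each process (`LazyMemmapDataset` is a hypothetical name):

```python
import os
import pickle
import tempfile

import numpy as np


class LazyMemmapDataset:
    """Hypothetical map-style dataset over a .npy file.

    Only the path and length are pickled, so sending it to spawned
    DataLoader workers stays cheap; each process maps the file itself.
    """

    def __init__(self, path):
        self.path = path
        self._data = None
        # Opening a memmap only maps the file; it does not read the data
        self._len = len(np.load(path, mmap_mode="r"))

    def _array(self):
        if self._data is None:
            self._data = np.load(self.path, mmap_mode="r")
        return self._data

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # np.array(...) copies just this row out of the memmap
        return np.array(self._array()[idx])

    def __getstate__(self):
        # Drop the open memmap so pickling does not serialize the array data
        state = self.__dict__.copy()
        state["_data"] = None
        return state


# Quick check with a small throwaway file
tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "dataset.npy")
np.save(path, np.arange(12, dtype=np.float32).reshape(4, 3))

ds = LazyMemmapDataset(path)
first_row = ds[1]                      # maps the file in this process
pickled_size = len(pickle.dumps(ds))   # path + length only
restored = pickle.loads(pickle.dumps(ds))
```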

How about writing your own data loader that runs in a separate thread? It can run in parallel with your training loop.

The data loading can look like this. I wrote it to cache the data in RAM, but I think it can also cache the data on the GPU in parallel.
(cache_size is the number of images or data items to cache in parallel.)

import os
import queue
import time
from threading import Thread

# Bounded queues holding cached [image, annotation, filename] items
trfqueue = queue.Queue(maxsize=20)  # training items
tsfqueue = queue.Queue(maxsize=20)  # test items

cache_size = 10      # number of items to keep cached ahead of the training loop
flist = []           # names of the training files loaded so far
n_trainingFiles = 0  # defined up front so the training loop can poll it safely
n_testFiles = 0


def LoadinThread(dirPath1, runtimeLength):
    """Read files from dirPath1 and push them onto the train/test queues."""
    global n_trainingFiles, n_testFiles
    fcount = 0
    n_testFiles = 0
    n_trainingFiles = 0

    trfqueue.queue.clear()
    tsfqueue.queue.clear()

    filelist = os.listdir(dirPath1)
    for file in filelist:
        if file[-3:] == 'jpg' and (fcount < 12 or runtimeLength == 2):
            fcount = fcount + 1

            # GetData is your own function that loads an image and its annotation.
            # HERE you can copy the data to the GPU or keep it in RAM.
            jpg, anno = GetData(file)
            templist = [jpg, anno, file]

            if dirPath1[-5:] == "Test/":
                # Counts half of the directory entries as samples
                n_testFiles = len(filelist) // 2
                while tsfqueue.qsize() > cache_size:
                    time.sleep(1)  # cache is full: wait for the consumer
                tsfqueue.put(templist)
            else:
                flist.append(file)
                n_trainingFiles = len(filelist) // 2
                while trfqueue.qsize() > cache_size:
                    time.sleep(2)  # cache is full: wait for the consumer
                trfqueue.put(templist)
    print("thread stopped !! ")


def StartDataCaching(dirpath1, runtimeLength):
    thread1 = Thread(target=LoadinThread, args=(dirpath1, runtimeLength))
    thread1.start()

The training loop can look like this:

import time
import torch.optim as optim


def trainmodel(model, BATCH_SIZE, lr_base, lr_max, runtimeLength):
    # dirpath and criterion are defined elsewhere in your script
    StartDataCaching(dirpath, runtimeLength)
    lr = lr_base
    fcount = 0
    running_loss = 0.0
    current_loss = 0.0

    # Wait until the loader thread has cached a few items
    while trfqueue.qsize() < 3 or n_trainingFiles == 0:
        time.sleep(2)
    lr_incr = (lr_max - lr) / n_trainingFiles

    # Create the optimizer once so Adam keeps its running moment estimates;
    # the learning rate is updated in place below.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    outputlist = []

    for t in range(int(n_trainingFiles)):
        # get() blocks until the loader thread has put an item on the queue
        inputs_reshaped, labels, filename = trfqueue.get()
        fcount += 1

        SZ = inputs_reshaped.shape[0]
        for i in range(0, SZ, BATCH_SIZE):
            for group in optimizer.param_groups:
                group['lr'] = lr
            inputs1 = inputs_reshaped[i:i + BATCH_SIZE]
            labels1 = labels[i:i + BATCH_SIZE]
            optimizer.zero_grad()

            loss = criterion(model(inputs1), labels1)
            loss.backward()
            optimizer.step()
            lr = lr + lr_incr
            current_loss += loss.item()

        outputlist.append([lr, current_loss, filename])

        print(filename + " - TRN .. " + repr(fcount) + "/" + repr(n_trainingFiles)
              + " loss: " + repr(round(current_loss, 2)))
        running_loss += current_loss
        current_loss = 0.0

    print('Current Training: avg loss: %.3f' % (running_loss / fcount))
    return outputlist

However, keep in mind that if your GPU has 8 GB of memory and your batch size already fills it, you can't additionally cache another 8 GB on the GPU. It will throw an out-of-memory error.
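To put the caveat in numbers: a dense tensor needs (product of its dimensions) × (bytes per element), so a GPU-side cache costs roughly that per cached item times cache_size, on top of the model, activations and gradients. The shapes below (256 RGB images at 224×224, float32) are hypothetical, chosen only for illustration.

```python
def tensor_bytes(*shape, bytes_per_element=4):
    """Bytes needed for a dense tensor of the given shape (float32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


# Hypothetical example: one cached item = 256 RGB images at 224x224, float32
batch = tensor_bytes(256, 3, 224, 224)
cache_size = 10
cached = batch * cache_size

MiB = 1024 ** 2
print(f"one item: {batch / MiB:.0f} MiB, {cache_size} cached items: {cached / MiB:.0f} MiB")
# prints "one item: 147 MiB, 10 cached items: 1470 MiB"
```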


This worked great, thanks.
I managed to get parallel loading working with the standard DataLoader as described in my initial question (it turns out the error I was getting was just because my batch size was too high).

But that didn't make my pipeline any faster until I tried your Queue idea.

Copying data to the GPU went from 5 seconds (serial) to 0.0002 seconds (the cost is now hidden by the Queue pipeline).

Thank you :)
