Simultaneously preprocess a batch on CPU and run forward/backward on GPU

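Yes. If multiple workers are used, they prepare the next batches in the background and won't wait until the current training iteration has finished. Once the queue is full, the workers wait, as @nivek described. You can change how far ahead they load via prefetch_factor: each worker loads up to prefetch_factor batches in advance, so up to num_workers * prefetch_factor batches can be prefetched in total (2 * num_workers with the default of 2).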
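You can reason about what prefetch_factor controls without starting any worker processes. The sketch below is a toy, single-threaded model written for this explanation (`simulate_prefetch` is a hypothetical helper, not a PyTorch function). It assumes batch_size=1, a sequential sampler, round-robin assignment of batches to workers, and a cap of num_workers * prefetch_factor outstanding batches. Under those assumptions the pipeline is filled up front, and each batch the training loop consumes frees exactly one slot for a new request.

```python
from collections import deque


def simulate_prefetch(num_batches, num_workers=2, prefetch_factor=2):
    """Toy model (hypothetical, not PyTorch code) of batch requests vs. training steps.

    Assumes batch_size=1, a sequential sampler and round-robin worker assignment.
    """
    log = []
    in_flight = deque()  # batches requested but not yet consumed by the loop
    next_idx = 0

    def request_next():
        nonlocal next_idx
        if next_idx < num_batches:  # stop once the sampler is exhausted
            worker = next_idx % num_workers  # round-robin over workers
            log.append(f"request index {next_idx} from worker {worker}")
            in_flight.append(next_idx)
            next_idx += 1

    # Fill the pipeline: num_workers * prefetch_factor batches up front.
    for _ in range(num_workers * prefetch_factor):
        request_next()

    # Consuming a batch frees one slot, so one new request is made
    # before the training step on that batch runs.
    while in_flight:
        idx = in_flight.popleft()
        request_next()
        log.append(f"train on index {idx}")
    return log


if __name__ == "__main__":
    for line in simulate_prefetch(num_batches=6)[:8]:
        print(line)
```

With the defaults (2 workers, prefetch_factor=2) the first training step starts after five requests (indices 0 to 4); raising prefetch_factor to 4 moves it to after nine requests.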
Here is a small example:

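import time

import torch
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self):
        # 100 samples, each a 1-element tensor
        self.data = torch.arange(100).view(-1, 1)

    def __getitem__(self, idx):
        # get_worker_info() returns None when called in the main process
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            print("worker {} loading index {}".format(worker_info.id, idx))
        x = self.data[idx]
        return x

    def __len__(self):
        return len(self.data)


# The main guard is required where workers are started with "spawn" (Windows, macOS)
if __name__ == "__main__":
    dataset = MyDataset()
    # batch_size=1; with num_workers > 0, prefetch_factor defaults to 2
    loader = DataLoader(dataset, num_workers=2)

    for data in loader:
        print("executing training")
        time.sleep(5)  # stand-in for the forward/backward pass
        print("training done")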

Output:

worker 0 loading index 0
worker 1 loading index 1
worker 0 loading index 2
worker 1 loading index 3
worker 0 loading index 4
executing training
worker 1 loading index 5
training done
executing training
worker 0 loading index 6
training done
executing training
worker 1 loading index 7
training done
executing training
worker 0 loading index 8
training done
executing training
worker 1 loading index 9
training done
executing training
...

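Data loading is fast in this case since it's simple indexing, so you can see that the workers fill the queue right away, before the “training” (a sleep call here) can even start: with 2 workers and the default prefetch_factor of 2, four batches are requested up front, and a fifth (index 4) is requested as soon as the first batch is returned to the loop. After that, each consumed batch frees a slot, and the next worker in turn starts loading and processing the next batch.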
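If you want to check the worker hand-off without reading interleaved prints, you can return the worker id together with each sample. `WhoLoadedMe` and `collect_worker_ids` below are hypothetical names made up for this sketch; `get_worker_info()` is the same PyTorch call used in the example above. Assuming batch_size=1, the default sequential sampler, and that all workers stay alive, batch i is expected to come from worker i % num_workers. That reflects how the current DataLoader cycles through its workers rather than a documented guarantee.

```python
import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info


class WhoLoadedMe(Dataset):
    """Hypothetical dataset: returns each index plus the id of the worker
    process that loaded it (-1 if it was loaded in the main process)."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        info = get_worker_info()  # None outside of a worker process
        worker_id = info.id if info is not None else -1
        return idx, worker_id


def collect_worker_ids(num_workers=2):
    # batch_size=1 (the default), so every batch holds a single index
    loader = DataLoader(WhoLoadedMe(), num_workers=num_workers, prefetch_factor=2)
    # default_collate turns each (idx, worker_id) sample into two 1-element tensors
    return [(int(idx), int(worker_id)) for idx, worker_id in loader]


if __name__ == "__main__":  # needed where workers are started with "spawn"
    for idx, worker_id in collect_worker_ids():
        print(f"index {idx} was loaded by worker {worker_id}")
```

Since the DataLoader returns batches in sampler order by default, the list comes back sorted by index even though two processes load the samples concurrently.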
