Simultaneously preprocess a batch on CPU and run forward/backward on GPU

Hi, this is a basic understanding question: how can I make a DataLoader start preparing the next batch while a model runs .forward() and .backward() on the previous batch?

When we write a loop to train or run inference with a PyTorch model, we loop over a DataLoader to get the next batch of samples, then run model.forward(x) within the loop (and, if training, also calculate the loss and call loss.backward()).

Simple inference loop:

for batch in dataloader:  # loads the next batch, which happens on the CPU and can take a while
    batch_tensors = batch["X"].to(device_name)
    logits = model.forward(batch_tensors)  # runs the forward pass on the GPU
    # the next batch doesn't start loading until this finishes (I assume)

In a scenario where the DataLoader uses several CPU workers to parallelize data loading, while model.forward() and loss.backward() run on the GPU, it seems like the workload will always be trading off between the CPU (while the DataLoader prepares the next batch) and the GPU (while model.forward() runs, for instance). In my use case, preprocessing in the DataLoader is heavy enough that the time to prepare a batch is about equal to the time to run the forward pass. So it seems that if the DataLoader immediately started preparing the next batch instead of waiting for the loop body to complete, inference would be roughly twice as fast.
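To make the arithmetic behind that claim explicit (the numbers below are invented purely for illustration, not measurements from my setup):

# Hypothetical per-batch timings, chosen only to illustrate the argument above.
t_preprocess = 0.5  # seconds the DataLoader spends preparing a batch on the CPU
t_forward = 0.5     # seconds the GPU spends on the forward pass for that batch

serial_cost = t_preprocess + t_forward          # 1.0 s per batch if the two alternate
overlapped_cost = max(t_preprocess, t_forward)  # 0.5 s per batch if they run concurrently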

Am I missing something here? Perhaps the DataLoader does actually start working on the next batch immediately rather than waiting for the top of the loop to be reached again.

If not, is there a way to achieve this simultaneous use of CPU preprocessing and GPU forward/backward passes?

Thank you for your help, I understand this may be an obvious or simple question for more experienced users.

You should be able to do that with the prefetch_factor argument of DataLoader (documentation):

  • prefetch_factor (int, optional , keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default: 2)
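For example, a minimal sketch of how to pass it (dataset stands for your existing Dataset, and the batch_size and num_workers values are placeholders, not taken from your setup):

from torch.utils.data import DataLoader

# With num_workers > 0, batches are prepared in background worker processes,
# and each worker keeps up to prefetch_factor batches queued ahead of the loop.
loader = DataLoader(
    dataset,            # your existing Dataset
    batch_size=32,
    num_workers=4,      # CPU worker processes running __getitem__ / preprocessing
    prefetch_factor=2,  # batches each worker loads in advance (this is the default)
    pin_memory=True,    # optional: can speed up CPU-to-GPU copies
)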

Let us know if you are seeing unexpected behaviors.


Yes, if multiple workers are used they will prepare the next batch and won’t wait until the training loop has finished. Once the queue is full, the workers will wait, as @nivek described, and you can change this behavior via prefetch_factor.
Here is a small example:

import time

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(-1, 1)

    def __getitem__(self, idx):
        # print which worker process is loading which sample
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            print("worker {} loading index {}".format(worker_info.id, idx))
        x = self.data[idx]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(dataset, num_workers=2)

for data in loader:
    print("executing training")
    time.sleep(5)  # stand-in for the forward/backward pass
    print("training done")

Output:

worker 0 loading index 0
worker 1 loading index 1
worker 0 loading index 2
worker 1 loading index 3
worker 0 loading index 4
executing training
worker 1 loading index 5
training done
executing training
worker 0 loading index 6
training done
executing training
worker 1 loading index 7
training done
executing training
worker 0 loading index 8
training done
executing training
worker 1 loading index 9
training done
executing training
...

The data loading is fast in this case, as it’s simple indexing, and you can see that the workers fill the queue before the “training” (which is just a sleep call here) can even start. Once a batch is consumed, the next worker starts loading and processing the following batch.
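To mimic the situation from the original question, where preprocessing is about as expensive as the forward pass, you could slow down __getitem__ artificially and time the loop. The sleep durations below are made up purely for illustration:

import time

import torch
from torch.utils.data import Dataset, DataLoader

class SlowDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(100).view(-1, 1)

    def __getitem__(self, idx):
        time.sleep(0.1)  # pretend per-sample preprocessing takes 0.1 s on the CPU
        return self.data[idx]

    def __len__(self):
        return len(self.data)

def run(num_workers):
    loader = DataLoader(SlowDataset(), batch_size=10, num_workers=num_workers)
    start = time.perf_counter()
    for batch in loader:
        time.sleep(1.0)  # pretend the forward/backward pass takes 1 s per batch
    return time.perf_counter() - start

if __name__ == "__main__":  # guard needed when workers are started via spawn
    print("num_workers=0:", run(0))  # roughly 20 s: 1 s loading + 1 s "compute" per batch, 10 batches
    print("num_workers=2:", run(2))  # roughly 11 s: workers prepare batches while the loop sleeps

The exact numbers will differ on your machine, but the point is that with workers the per-batch cost approaches the slower of the two stages rather than their sum.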
