How does prefetch factor really work?

The docs (PyTorch 1.13 documentation) define prefetch_factor as the “Number of batches loaded in advance by each worker,” implying that with W workers and a prefetch factor of P, WxP batches are loaded in total. But must all of these batches be loaded in advance, or is WxP just the maximum number of batches that could be loaded in advance?

For example, consider the following code snippet:

for batch in data_loader:
  # forward & backward passes on the model

For the above code snippet, does the data_loader behave in the following way:

  1. Before the first iteration, WxP batches are loaded
  2. After WxP iterations, another set of WxP batches is loaded

Or does the data_loader behave in this way:

  1. Before the first iteration, WxP batches are loaded
  2. After the first iteration, WxP - 1 batches now reside in host memory, so a worker from the pool fetches a new batch to bring the total back to WxP.

If the latter is correct, does the 2nd step happen synchronously or asynchronously? That is, do I need to wait for the new batch to be fetched (bringing the total back to WxP) before proceeding to the next iteration, or, given that I already have WxP - 1 batches to feed, can the remaining batch be fetched asynchronously?

Thanks in advance for answering!

I don’t think either description is exactly correct.

In general, DataLoader tries to return the first batch as soon as possible upon request (rather than load WxP batches in advance before returning). With prefetch_factor > 0, while the forward and backward passes happen, DataLoader tries to prepare as many subsequent batches as possible up to the limit set by the prefetch_factor and saves those batches in a buffer. When the model requests the next batch, DataLoader immediately pops off the first batch from the buffer, regardless of whether the buffer is full or not. Then, it goes back and tries to fill up the buffer while the model executes.
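The pop-then-refill behavior described above can be sketched with a toy, stdlib-only loader. This is an illustrative simplification, not PyTorch's actual worker implementation (the function name `toy_prefetch_loader` and its structure are invented for the example): workers keep a bounded buffer of at most num_workers * prefetch_factor batches filled, while the consumer pops one batch at a time without ever waiting for the buffer to be full.

```python
import queue
import threading

def toy_prefetch_loader(dataset, num_workers=2, prefetch_factor=2):
    """Toy sketch of prefetching: NOT PyTorch's real implementation.

    Workers fill a buffer bounded at num_workers * prefetch_factor;
    the consumer pops batches off the front as soon as they exist.
    """
    # The buffer size is the prefetch limit: workers block on put()
    # once it holds num_workers * prefetch_factor batches.
    buf = queue.Queue(maxsize=num_workers * prefetch_factor)
    indices = iter(range(len(dataset)))
    lock = threading.Lock()

    def worker():
        while True:
            with lock:
                i = next(indices, None)  # claim the next index
            if i is None:
                break  # dataset exhausted
            buf.put(dataset[i])  # blocks only when the buffer is full

    threads = [threading.Thread(target=worker, daemon=True)
               for _ in range(num_workers)]
    for t in threads:
        t.start()

    for _ in range(len(dataset)):
        # Pop immediately, whether or not the buffer is full;
        # workers keep refilling it in the background.
        yield buf.get()

batches = list(toy_prefetch_loader(list(range(10))))
```

Because the workers refill the buffer concurrently, the consumer only ever blocks when the buffer is empty, which mirrors the "pop first, refill while the model executes" behavior. (With multiple workers, batch order is not guaranteed in this toy version.)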


This code snippet also illustrates your description in case it’s helpful.