prefetch_factor in DataLoader

Hi,

According to the documentation, prefetch_factor is the number of samples loaded in advance by each worker, and it defaults to 2.
I’m wondering what the point is of pre-loading only 2 examples rather than, say, 2 batches of data.
Does pre-loading a few examples really help?

Thanks

As with many things, the best way to answer a setup-dependent question like that is to instrument a working example. The “optimal” prefetch factor will vary with the speed of model execution, the speed of storage, the number of workers, and the OS filesystem caching policy, so if you find evidence that this isn’t a sane default, please open an upstream issue or PR!
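
One way to instrument it is a small timing harness like the sketch below. It is generic over any iterable, so make_loader and step are hypothetical callables you supply: make_loader(pf) should build a fresh loader with the given prefetch_factor, and step(batch) should do (or simulate) the per-batch model work.

```python
import time
from itertools import islice
from typing import Any, Callable, Dict, Iterable, Sequence


def time_loader(make_loader: Callable[[int], Iterable[Any]],
                prefetch_factors: Sequence[int],
                n_batches: int = 100,
                step: Callable[[Any], Any] = lambda batch: None) -> Dict[int, float]:
    """Wall-clock seconds to consume n_batches for each prefetch_factor.

    make_loader(pf) is a hypothetical factory that builds a fresh loader with
    the given prefetch_factor; step(batch) stands in for the model's work.
    iter() is called inside the timed region, so loader start-up is included.
    """
    results: Dict[int, float] = {}
    for pf in prefetch_factors:
        loader = make_loader(pf)
        start = time.perf_counter()
        for batch in islice(loader, n_batches):  # stop after n_batches
            step(batch)  # run (or simulate) one training step
        results[pf] = time.perf_counter() - start
    return results
```

For example, time_loader(lambda pf: DataLoader(dataset, batch_size=32, num_workers=4, prefetch_factor=pf), [1, 2, 4, 8], step=train_step), where dataset and train_step are your own. If step does no real work, the numbers only measure loading throughput, not whether the model is ever starved.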

It does pre-load 2 batches of data per worker (not samples), i.e. prefetch_factor * num_workers batches in total. The docstring on the master branch has been corrected.


I’m confused.
Looking at the _reset method of the class _MultiProcessingDataLoaderIter in torch/utils/data/dataloader.py, it seems that prefetch_factor refers to the “Number of samples loaded in advance by each worker.”

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        # Note that this indicates that a worker still has work to do *for this epoch*.
        # It does not mean that a worker is dead. In case of `_persistent_workers`,
        # the worker will be reset to available in the next epoch.
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                return_idx, return_data = self._get_data()
                if isinstance(return_idx, _utils.worker._ResumeIteration):
                    assert return_data is None
                    resume_iteration_cnt -= 1
        # prime the prefetch loop
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()

The function self._try_put_index() puts the index of one sample at a time into self._index_queues, and it increments self._send_idx by 1.
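
A simplified, pure-Python model of the priming loop helps check what “one index” means. As far as I can tell from the source, each _try_put_index call takes the next element of the loader’s index sampler and sends it to the next worker in round-robin order. The sketch below is not the real implementation (prime_prefetch is a hypothetical name, and it ignores per-worker status and the task bookkeeping); it only shows that when the index sampler yields batches, each put is a whole batch.

```python
from itertools import cycle
from typing import Any, Dict, Iterable, List


def prime_prefetch(index_sampler: Iterable[Any], num_workers: int,
                   prefetch_factor: int) -> Dict[int, List[Any]]:
    """Simplified model of the priming loop in _reset (not the real code).

    Each put takes ONE element from the index sampler (a sample index, or a
    list of indices when batching) and assigns it to the next worker.
    """
    queues: Dict[int, List[Any]] = {w: [] for w in range(num_workers)}
    sampler_iter = iter(index_sampler)
    worker_cycle = cycle(range(num_workers))  # round-robin over workers
    for _ in range(prefetch_factor * num_workers):
        try:
            task = next(sampler_iter)
        except StopIteration:
            break  # sampler exhausted before the prefetch window filled
        queues[next(worker_cycle)].append(task)
    return queues
```

With batches of 3 indices, prime_prefetch([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], num_workers=2, prefetch_factor=2) gives worker 0 the batches [0, 1, 2] and [6, 7, 8] and worker 1 the batches [3, 4, 5] and [9, 10, 11]: 2 batches per worker, 12 samples in total. With a plain sample sampler such as range(30), the same call gives each worker only 2 sample indices.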

I see.
So the number of samples prefetched is prefetch_factor * batch_size * num_workers when a batch sampler is used.
If batch_size is None, no batch sampler is used, so in that case it is prefetch_factor * num_workers samples.
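
The arithmetic above can be written as a tiny helper (samples_in_flight is a hypothetical name, not a PyTorch function). It gives an upper bound, since the last batch may be smaller when drop_last=False and the dataset may be shorter than the prefetch window.

```python
from typing import Optional


def samples_in_flight(num_workers: int, prefetch_factor: int,
                      batch_size: Optional[int]) -> int:
    """Upper bound on samples requested ahead of time by the priming loop."""
    tasks = prefetch_factor * num_workers  # one task per _try_put_index call
    if batch_size is None:
        return tasks  # no batch sampler: each task is a single sample index
    return tasks * batch_size  # batch sampler: each task is one full batch
```

For example, 4 workers with the default prefetch_factor of 2 and batch_size=32 request up to 256 samples ahead, versus 8 samples with batch_size=None.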
Have I understood correctly?

pytorch/torch/utils/data/dataloader.py : DataLoader

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
...
        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
...

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler
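
To see what each “index” is in the two modes, here is a pure-Python stand-in for the _index_sampler choice. batch_indices and index_sampler are hypothetical names; batch_indices groups indices roughly the way BatchSampler does, and the sketch ignores the case where you pass your own batch_sampler.

```python
from typing import Iterable, Iterator, List, Optional, Union


def batch_indices(sampler: Iterable[int], batch_size: int,
                  drop_last: bool) -> Iterator[List[int]]:
    """Group sample indices into lists of batch_size (BatchSampler-like)."""
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch  # keep the short final batch unless drop_last


def index_sampler(sampler: Iterable[int], batch_size: Optional[int],
                  drop_last: bool = False) -> Iterable[Union[int, List[int]]]:
    """Mirror of the _index_sampler choice: batches if auto-collating."""
    if batch_size is not None:  # auto-collation: a batch sampler is built
        return batch_indices(sampler, batch_size, drop_last)
    return sampler  # batch_size=None: one sample index per task
```

list(index_sampler(range(10), 3)) yields [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]], so each prefetched task is a whole batch; with batch_size=None it yields the bare indices 0 through 9.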