How do I use pinned memory with multiple workers in a PyTorch DataLoader?

(Full disclosure: I’ve asked this question on StackOverflow as well. Whoever answers it here may wish to answer it there too: How do I use pinned memory with multiple workers in a PyTorch DataLoader? - Stack Overflow)

I’m learning about methods to accelerate the training of deep-learning models using PyTorch. From Pin_memory and num_workers in pytorch data loaders, it seems that it should be possible (whether it is desirable is a separate question) to pass both pin_memory=True and num_workers=1 (or 2, 3, etc.) as arguments to a DataLoader, provided some care is taken with the collator function passed as the named argument collate_fn. To test a few scenarios, I wrote the following code:

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

def collator(batch):
    X, Y = [], []
    for x, y in batch:
        X += [x, ]
        Y += [y, ]

    return torch.stack(X), torch.stack(Y)
    

def pin_collator(batch):
    X, Y = collator(batch)

    return X.pin_memory(), Y.pin_memory()


class TestDataset(Dataset):
    def __init__(self):
        self.x = [torch.randn(3) for _ in range(4)]
        self.y = [(x*x).sum() for x in self.x]

    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
    

dataset = TestDataset()

print("Not pinned, 0 workers")
dataloader = DataLoader(dataset, num_workers=0, batch_size=2, collate_fn=collator)
for batch in dataloader:
    print(batch)

print("Pinned, 0 workers")
pin_dataloader = DataLoader(dataset, num_workers=0, batch_size=2, collate_fn=pin_collator, pin_memory=True)
for batch in pin_dataloader:
    print(batch)

print("Not pinned, 1 worker")
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, collate_fn=collator)
for batch in dataloader:
    print(batch)

print("Pinned, 1 worker")
pin_dataloader = DataLoader(dataset, num_workers=1, batch_size=2, collate_fn=pin_collator, pin_memory=True)
for batch in pin_dataloader:
    print(batch)

However, on Ubuntu 22.04, either “native” or under WSL, using PyTorch 2.0 and CUDA 11.8, I got the following error message from the fourth data-loading block (“Pinned, 1 worker”):

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/user_name/miniforge3/envs/pytorch_2-0/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/user_name/miniforge3/envs/pytorch_2-0/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/tmp/ipykernel_822116/2862910865.py", line 17, in pin_collator
    return X.pin_memory(), Y.pin_memory()
RuntimeError: CUDA driver error: initialization error

Here’s the full Traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 51
     49 print("Pinned, 1 worker")
     50 pin_dataloader = DataLoader(dataset, num_workers=1, batch_size=2, collate_fn=pin_collator, pin_memory=True)
---> 51 for batch in pin_dataloader:
     52     print(batch)

File ~/miniforge3/envs/pytorch_2-0/lib/python3.10/site-packages/torch/utils/data/dataloader.py:634, in _BaseDataLoaderIter.__next__(self)
    631 if self._sampler_iter is None:
    632     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    633     self._reset()  # type: ignore[call-arg]
--> 634 data = self._next_data()
    635 self._num_yielded += 1
    636 if self._dataset_kind == _DatasetKind.Iterable and \
    637         self._IterableDataset_len_called is not None and \
    638         self._num_yielded > self._IterableDataset_len_called:

File ~/miniforge3/envs/pytorch_2-0/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1346, in _MultiProcessingDataLoaderIter._next_data(self)
   1344 else:
   1345     del self._task_info[idx]
-> 1346     return self._process_data(data)

File ~/miniforge3/envs/pytorch_2-0/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1372, in _MultiProcessingDataLoaderIter._process_data(self, data)
   1370 self._try_put_index()
   1371 if isinstance(data, ExceptionWrapper):
-> 1372     data.reraise()
   1373 return data

File ~/miniforge3/envs/pytorch_2-0/lib/python3.10/site-packages/torch/_utils.py:644, in ExceptionWrapper.reraise(self)
    640 except TypeError:
    641     # If the exception takes multiple arguments, don't try to
    642     # instantiate since we don't know how to
    643     raise RuntimeError(msg) from None
--> 644 raise exception

I’m primarily concerned with a Linux implementation, and perhaps this additional information doesn’t add much to the question, but, for what it’s worth, on Windows 11 both the third and the fourth data-loading blocks (the fourth after commenting out the third) fail with the following error message:

RuntimeError: DataLoader worker (pid(s) 9364) exited unexpectedly

The full traceback is the same for both blocks:

---------------------------------------------------------------------------
Empty                                     Traceback (most recent call last)
File c:\Users\user_name\AppData\Local\mambaforge\envs\pytorch_2.0\lib\site-packages\torch\utils\data\dataloader.py:1133, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
   1132 try:
-> 1133     data = self._data_queue.get(timeout=timeout)
   1134     return (True, data)

File c:\Users\user_name\AppData\Local\mambaforge\envs\pytorch_2.0\lib\multiprocessing\queues.py:114, in Queue.get(self, block, timeout)
    113     if not self._poll(timeout):
--> 114         raise Empty
    115 elif not self._poll():

Empty: 

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[7], line 46
     44 print("Not pinned, 1 worker")
     45 dataloader = DataLoader(dataset, num_workers=1, batch_size=2, collate_fn=collator)
---> 46 for batch in dataloader:
     47     print(batch)
     49 print("Pinned, 1 worker")

File c:\Users\user_name\AppData\Local\mambaforge\envs\pytorch_2.0\lib\site-packages\torch\utils\data\dataloader.py:634, in _BaseDataLoaderIter.__next__(self)
    631 if self._sampler_iter is None:
    632     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    633     self._reset()  # type: ignore[call-arg]
--> 634 data = self._next_data()
    635 self._num_yielded += 1
    636 if self._dataset_kind == _DatasetKind.Iterable and \
    637         self._IterableDataset_len_called is not None and \
    638         self._num_yielded > self._IterableDataset_len_called:

File c:\Users\user_name\AppData\Local\mambaforge\envs\pytorch_2.0\lib\site-packages\torch\utils\data\dataloader.py:1329, in _MultiProcessingDataLoaderIter._next_data(self)
   1326     return self._process_data(data)
   1328 assert not self._shutdown and self._tasks_outstanding > 0
-> 1329 idx, data = self._get_data()
   1330 self._tasks_outstanding -= 1
   1331 if self._dataset_kind == _DatasetKind.Iterable:
   1332     # Check for _IterableDatasetStopIteration

File c:\Users\user_name\AppData\Local\mambaforge\envs\pytorch_2.0\lib\site-packages\torch\utils\data\dataloader.py:1295, in _MultiProcessingDataLoaderIter._get_data(self)
   1291     # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
   1292     # need to call `.task_done()` because we don't use `.join()`.
   1293 else:
   1294     while True:
-> 1295         success, data = self._try_get_data()
   1296         if success:
   1297             return data

File c:\Users\user_name\AppData\Local\mambaforge\envs\pytorch_2.0\lib\site-packages\torch\utils\data\dataloader.py:1146, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
   1144 if len(failed_workers) > 0:
   1145     pids_str = ', '.join(str(w.pid) for w in failed_workers)
-> 1146     raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
   1147 if isinstance(e, queue.Empty):
   1148     return (False, None)

My CUDA installation otherwise behaves fine: I can train and run inference with models on my CUDA GPU, with and without pinned memory, when num_workers=0.

Hence the title question: How do I use pinned memory with multiple workers in a PyTorch DataLoader?

It seems you want to re-implement a simple custom collate_fn that also pins memory, which then fails due to a re-initialization of the CUDA context: pinning requires CUDA, and CUDA cannot be initialized again in a worker process forked from a parent that has already initialized it.
The easier approach would be to just use the pin_memory argument of the DataLoader; I’m unsure why you would want to create multiple pools of pinned memory.
The details about the internal usage of pinned memory in the DataLoader, including how the pin_memory_thread works, can be found here.
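A minimal sketch of the suggested approach, repeating the question's TestDataset with a simplified collator: the collator returns plain CPU tensors and pin_memory=True is left to the DataLoader, so the worker processes never call into CUDA. The helper names make_loader and run are hypothetical, and the module-level definitions plus the `if __name__ == "__main__":` guard are additions of this sketch for start methods that spawn fresh worker processes (the default on Windows); whether that is also behind the Windows failure in the question is an assumption, not something confirmed in this thread.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def collator(batch):
    # Plain CPU collation: no CUDA calls, so it is safe inside worker processes.
    X, Y = zip(*batch)
    return torch.stack(X), torch.stack(Y)


class TestDataset(Dataset):
    def __init__(self):
        self.x = [torch.randn(3) for _ in range(4)]
        self.y = [(x * x).sum() for x in self.x]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


def make_loader(num_workers=2):
    # pin_memory=True: the DataLoader pins each collated batch in the main
    # process; without a CUDA device pinning is skipped (possibly with a warning).
    return DataLoader(
        TestDataset(),
        batch_size=2,
        num_workers=num_workers,
        collate_fn=collator,
        pin_memory=True,
    )


def run(num_workers=2):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batches = []
    for X, Y in make_loader(num_workers):
        # non_blocking=True can overlap the host-to-device copy when X and Y are pinned.
        batches.append((X.to(device, non_blocking=True), Y.to(device, non_blocking=True)))
    return batches


if __name__ == "__main__":
    for X, Y in run(num_workers=2):
        print(X.device, tuple(X.shape), tuple(Y.shape))
```

The design point is that pinning happens once, in the main process, after the batch has crossed the worker boundary, rather than separately in each worker.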

Thank you for referring me to the details of the internal usage of pinned memory. I see now that there is no need for me to explicitly pin the data in my custom collator, as the pin_memory argument does that for me.
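If a custom collator has to return something other than tensors or standard containers, the PyTorch data-loading documentation describes defining a pin_memory() method on the custom batch type; the DataLoader then calls that method when it pins the batch in the main process, so the collator itself still never pins. A sketch under that assumption, using a TensorDataset of (x, y) pairs in place of the question's dataset; the names PinnableBatch, collate_to_batch and load_all are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class PinnableBatch:
    """Hypothetical custom batch type holding stacked inputs and targets."""

    def __init__(self, samples):
        xs, ys = zip(*samples)
        self.X = torch.stack(xs)
        self.Y = torch.stack(ys)

    # Called by the DataLoader in the main process when pin_memory=True and a
    # CUDA device is available; the worker processes never call it.
    def pin_memory(self):
        self.X = self.X.pin_memory()
        self.Y = self.Y.pin_memory()
        return self


def collate_to_batch(samples):
    # Runs in the worker processes: CPU-only work, no pinning here.
    return PinnableBatch(samples)


def load_all(num_workers=2):
    dataset = TensorDataset(torch.randn(4, 3), torch.randn(4))
    loader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=num_workers,
        collate_fn=collate_to_batch,
        pin_memory=True,
    )
    return list(loader)


if __name__ == "__main__":
    for batch in load_all():
        print(type(batch).__name__, tuple(batch.X.shape), tuple(batch.Y.shape))
```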

@ptrblck, would you mind posting your answer on the StackOverflow question linked at the top of this thread? I’d rather not do it myself, as it is not my answer.

If you want you could link to this thread, but I don’t have a StackOverflow account.