Torch dataloader num_workers>0 not spawning workers

I’m porting code from Keras to PyTorch. I’m working with many GPUs and CPUs, so it’s important that batch generation happens in parallel. My problem is that when I try to use the num_workers argument on the DataLoader class, I run into errors. Here is the relevant toy code:

import torch
torch.multiprocessing.set_start_method('spawn')
from torch.utils.data import Dataset, DataLoader

X_train = torch.randn(100,100)
y_train = torch.randn(100,100)
w_train = torch.randn(100,100)

class My_Dataset(Dataset):
    def __init__(self, x_input, y_labels, w_labels):
        self.y_labels = y_labels
        self.w_labels = w_labels
        self.x_input = x_input
        self.len = self.x_input.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
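        # Intended to flag when __getitem__ runs in the main process instead of a worker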
        if __name__ == '__main__':
            print('Error! being run by {}'.format(__name__))
        
        X = self.x_input[idx]
        y = self.y_labels[idx]
        w = self.w_labels[idx]
        return X, y, w

training_dataset = My_Dataset(X_train, y_train, w_train)
training_dataloader = DataLoader(training_dataset, batch_size=10, shuffle=True, num_workers=15, pin_memory=False)

for batch_num, (inputs, labels_y, labels_w) in enumerate(training_dataloader):
    (inputs, labels_y, labels_w) = (inputs.cuda(), labels_y.cuda(), labels_w.cuda())
    print(inputs)

Removing torch.multiprocessing.set_start_method('spawn') lets the code run, but batch generation happens in the main process (the message I wrote into the dataset prints, and on my non-toy problem batch generation takes unacceptably long). I’ve tried both 'spawn' and 'forkserver', but both fail, and I can’t even figure out how to get a relevant error message. Right now the error just says that one of the worker processes failed and that setting num_workers=0 might give a better trace.

Python version = 3.7.9, torch version = 1.7.0

Unsure if you are using Windows, but you could try to use the if-clause protection as described here.
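Roughly, the idea is to guard everything that starts processes, including set_start_method, behind the __main__ check, because spawned workers re-import the module and would otherwise execute the module-level code again. A minimal sketch (the dataset is just a stand-in):

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self, n=100):
        self.x = torch.randn(n, 100)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx]

def main():
    loader = DataLoader(ToyDataset(), batch_size=10, shuffle=True, num_workers=4)
    for batch in loader:
        pass

if __name__ == '__main__':
    # Only the parent process runs this block; with 'spawn' the workers
    # re-import the module, so code outside the guard would run in them too.
    torch.multiprocessing.set_start_method('spawn')
    main()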

Could you also post the error messages you are seeing?

I’m on Linux. I implemented the if-clause protection, but it didn’t change the behavior of the code.

Behavior with torch.multiprocessing.set_start_method('fork'): the following prints 100 times:
Error! being run by __main__

Behavior with torch.multiprocessing.set_start_method('spawn'):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _try_get_data(self, timeout)
    871         try:
--> 872             data = self._data_queue.get(timeout=timeout)
    873             return (True, data)

~/.conda/envs/torch_env/lib/python3.7/multiprocessing/queues.py in get(self, block, timeout)
    103                     timeout = deadline - time.monotonic()
--> 104                     if not self._poll(timeout):
    105                         raise Empty

~/.conda/envs/torch_env/lib/python3.7/multiprocessing/connection.py in poll(self, timeout)
    256         self._check_readable()
--> 257         return self._poll(timeout)
    258 

~/.conda/envs/torch_env/lib/python3.7/multiprocessing/connection.py in _poll(self, timeout)
    413     def _poll(self, timeout):
--> 414         r = wait([self], timeout)
    415         return bool(r)

~/.conda/envs/torch_env/lib/python3.7/multiprocessing/connection.py in wait(object_list, timeout)
    920             while True:
--> 921                 ready = selector.select(timeout)
    922                 if ready:

~/.conda/envs/torch_env/lib/python3.7/selectors.py in select(self, timeout)
    414         try:
--> 415             fd_event_list = self._selector.poll(timeout)
    416         except InterruptedError:

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/_utils/signal_handling.py in handler(signum, frame)
     65         # Python can still get and update the process status successfully.
---> 66         _error_if_any_worker_fails()
     67         if previous_handler is not None:

RuntimeError: DataLoader worker (pid 17676) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-7-b7dbc62fb341> in <module>
      5 
      6 if __name__ == '__main__':
----> 7     main()

<ipython-input-7-b7dbc62fb341> in main()
      1 def main():
----> 2     for batch_num, (inputs, labels_y, labels_w) in enumerate(training_dataloader):
      3         (inputs, labels_y, labels_w) =(inputs.cuda(), labels_y.cuda(), labels_w.cuda())
      4     #print(inputs)
      5 

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1066 
   1067             assert not self._shutdown and self._tasks_outstanding > 0
-> 1068             idx, data = self._get_data()
   1069             self._tasks_outstanding -= 1
   1070             if self._dataset_kind == _DatasetKind.Iterable:

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _get_data(self)
   1032         else:
   1033             while True:
-> 1034                 success, data = self._try_get_data()
   1035                 if success:
   1036                     return data

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _try_get_data(self, timeout)
    883             if len(failed_workers) > 0:
    884                 pids_str = ', '.join(str(w.pid) for w in failed_workers)
--> 885                 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
    886             if isinstance(e, queue.Empty):
    887                 return (False, None)

RuntimeError: DataLoader worker (pid(s) 17676) exited unexpectedly

Behavior with torch.multiprocessing.set_start_method('forkserver'):

---------------------------------------------------------------------------
Empty                                     Traceback (most recent call last)
~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _try_get_data(self, timeout)
    871         try:
--> 872             data = self._data_queue.get(timeout=timeout)
    873             return (True, data)

~/.conda/envs/torch_env/lib/python3.7/multiprocessing/queues.py in get(self, block, timeout)
    104                     if not self._poll(timeout):
--> 105                         raise Empty
    106                 elif not self._poll():

Empty: 

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-7-b7dbc62fb341> in <module>
      5 
      6 if __name__ == '__main__':
----> 7     main()

<ipython-input-7-b7dbc62fb341> in main()
      1 def main():
----> 2     for batch_num, (inputs, labels_y, labels_w) in enumerate(training_dataloader):
      3         (inputs, labels_y, labels_w) =(inputs.cuda(), labels_y.cuda(), labels_w.cuda())
      4     #print(inputs)
      5 

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1066 
   1067             assert not self._shutdown and self._tasks_outstanding > 0
-> 1068             idx, data = self._get_data()
   1069             self._tasks_outstanding -= 1
   1070             if self._dataset_kind == _DatasetKind.Iterable:

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _get_data(self)
   1032         else:
   1033             while True:
-> 1034                 success, data = self._try_get_data()
   1035                 if success:
   1036                     return data

~/.conda/envs/torch_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _try_get_data(self, timeout)
    883             if len(failed_workers) > 0:
    884                 pids_str = ', '.join(str(w.pid) for w in failed_workers)
--> 885                 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
    886             if isinstance(e, queue.Empty):
    887                 return (False, None)

RuntimeError: DataLoader worker (pid(s) 17746, 17747, 17748, 17749, 17750, 17751, 17752, 17753, 17754, 17755, 17756, 17757, 17758, 17759, 17760) exited unexpectedly

If anyone has suggestions for getting more specific error messages, that would also be very helpful; right now I don’t know how to make anything print from inside the worker processes. A toy example of a DataLoader with num_workers>0 that plays nice would also be very welcome.

Based on the first post I assume that your code runs fine with num_workers=0 for the full epoch and doesn’t yield any error?
If that’s the case, could you check if you have enough shared memory available?

As I understand it, your main concern is that batch generation appears to run only in the main process.

It may look like batch generation happens only in the main process, but if you check the output of the Linux top command while the program is running, you will clearly see that additional processes have been started and are using CPU time.

With regard to the printed message about executing in __main__, I am not sure, but I think every DataLoader worker gets its own copy of the dataset, and with 'fork' the child processes inherit the parent’s __name__, so that check does not tell you which process is actually running __getitem__.
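A more reliable check than testing __name__ is torch.utils.data.get_worker_info(), which returns None in the main process and a WorkerInfo object inside a worker. A minimal sketch (the print is only for illustration):

import os
import torch
from torch.utils.data import Dataset, DataLoader, get_worker_info

class ProbeDataset(Dataset):
    def __init__(self, n=100):
        self.x = torch.randn(n, 10)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        info = get_worker_info()  # None in the main process, WorkerInfo in a worker
        who = 'main process' if info is None else 'worker {}'.format(info.id)
        print('sample {} loaded by {} (pid {})'.format(idx, who, os.getpid()))
        return self.x[idx]

if __name__ == '__main__':
    loader = DataLoader(ProbeDataset(), batch_size=10, num_workers=2)
    for batch in loader:
        pass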

Yes, the code runs with num_workers=0 for the full epoch and doesn’t yield any error (other than the one I wrote in).

As for shared memory, the system has several hundred GB of memory, which should be more than enough for 100*100 random floats even with 15 workers. Am I missing something special about shared memory?

If it is spinning up more processes, they aren’t doing anything to make things faster. On my real problem, with the GPU also involved, I can see that batch generation is throttling training (adding operations to __getitem__() makes iterations take longer), and setting num_workers=15 doesn’t speed things up the way it does in Keras.

Yes, system RAM isn’t the same thing as shared memory; the shared memory size is a separate limit that you can adjust in your system setup.
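To see how much shared memory is actually available, you can check the /dev/shm mount, which is what the DataLoader workers typically use to pass tensors between processes. A quick sketch (assumes a standard Linux setup):

import shutil

# /dev/shm is the tmpfs mount that PyTorch worker processes typically use for shared tensors
total, used, free = shutil.disk_usage('/dev/shm')
print('/dev/shm total: {:.1f} GiB, free: {:.1f} GiB'.format(total / 2**30, free / 2**30))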

Projects can get really complicated. To understand the problem I would profile the program and find the bottleneck. It could be in any major part of the pipeline, from preparing data and generating batches to training the model on the distributed system. For example, I would check whether the CPUs are struggling to prepare data so the GPU(s) spend a lot of time waiting for it, or whether it is the other way around and the model is so big that prepared batches wait a long time before they reach the GPU. It could also be a problem with utilizing all GPUs correctly in distributed parallel training, and so on.
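A rough way to see which side is the bottleneck is to time the data loading and the training step separately; a minimal sketch (step_fn is a placeholder for your own training step):

import time
import torch

def profile_one_epoch(loader, step_fn):
    # Accumulate time spent waiting for batches vs. time spent in the training step.
    data_time = 0.0
    step_time = 0.0
    end = time.perf_counter()
    for batch in loader:
        data_time += time.perf_counter() - end
        start = time.perf_counter()
        step_fn(batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the GPU work to finish before stopping the clock
        step_time += time.perf_counter() - start
        end = time.perf_counter()
    print('waiting for data: {:.1f}s, training steps: {:.1f}s'.format(data_time, step_time))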

On my problem, if I just have the dataloader supply a tensor from memory, it takes ~1 s/iteration. However, my setup requires that during training I add a tensor of random noise to the input tensor. Generating the random tensor is CPU intensive, and the time per iteration jumps to ~6 s/iteration with 1 CPU (I can’t precompute the noise because I add different noise every epoch). In Keras, adding workers brings this back down to ~1 s/iteration because I can use many CPUs to generate the random noise and prepare the batch. Right now I haven’t found an analog in PyTorch that lets me do this successfully.

OK, now I can see you have a troublesome point in your pipeline. It may help to post some code snippets showing how you create the tensors and add the noise tensors, so other people in the PyTorch community can help improve the performance.
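For reference, the pattern that usually parallelizes well is drawing the noise per sample inside __getitem__, so that each worker generates its own noise. A rough sketch with placeholder shapes and noise scale (your real pipeline will differ):

import torch
from torch.utils.data import Dataset, DataLoader

class NoisyDataset(Dataset):
    def __init__(self, x, y, noise_std=0.1):
        self.x = x
        self.y = y
        self.noise_std = noise_std

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        x = self.x[idx]
        # Fresh noise is drawn on every call, so every epoch sees different noise,
        # and with num_workers > 0 the generation happens in the worker processes.
        noisy = x + self.noise_std * torch.randn_like(x)
        return noisy, self.y[idx]

if __name__ == '__main__':
    ds = NoisyDataset(torch.randn(100, 100), torch.randn(100, 100))
    loader = DataLoader(ds, batch_size=10, shuffle=True, num_workers=4)
    for noisy_x, y in loader:
        pass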

Looks like this was probably the issue. The system only has 45k of shared memory. I found this thread to be relevant.

Unfortunately, I don’t think I’m going to be able to make changes to the system settings.

I had the same problem; I downgraded to Python 3.6 and it worked.

In my experience, num_workers behaves unpredictably when it is set high. With 16 logical threads I can usually only run around 8 workers as a stable configuration, and even then I get a memory error once in a blue moon.

If anyone else is having problems, I would suggest running with 4-6 workers; more is often not needed.
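Something like the following keeps the worker count bounded regardless of the machine (just a suggestion, not a hard rule; the dataset here is a stand-in):

import os
import torch
from torch.utils.data import DataLoader, TensorDataset

# Cap the worker count at 6, or at the CPU count if the machine has fewer cores.
num_workers = min(6, os.cpu_count() or 1)

dataset = TensorDataset(torch.randn(100, 100), torch.randn(100, 100))
loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=num_workers)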