DataLoader multiprocessing with Dataset returning a CUDA tensor


I have a use case where my Dataset’s __getitem__ fetches a tensor from the main process’s memory, transfers it to the GPU to do some basic image processing much faster than the CPU could, and returns a CUDA tensor directly. A DataLoader then sits on top of that.
The DataLoader works fine with num_workers=0, but I get errors whenever I try to enable multiprocessing by increasing num_workers.

Here’s an MWE:

import torch

class CudaDataset(torch.utils.data.Dataset):
    def __init__(self, device):
        self.tensor_on_ram = torch.Tensor([1, 2, 3])
        self.device = device

    def __len__(self):
        return len(self.tensor_on_ram)

    def __getitem__(self, index):
        return self.tensor_on_ram[index].to(self.device)

ds = CudaDataset(torch.device('cuda:0'))
dl = torch.utils.data.DataLoader(ds, batch_size=1, pin_memory=False, num_workers=2)

# First pass runs with no issue at all
for i in dl:
    pass

## Let's do it a second time
for i in dl:  # Here it throws an error
    pass

Here’s the error:

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-1-3f858a29a121>", line 12, in __getitem__
    return self.tensor_on_ram[index].to(self.device)
  File "/usr/local/lib/python3.7/dist-packages/torch/cuda/", line 207, in _lazy_init
    "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

I’ve tried adding the line torch.multiprocessing.set_start_method('spawn') at the top, but then I get DataLoader worker (pid(s) 1078) exited unexpectedly

I’m not sure whether this use case is even possible; I just wanted to benchmark the performance gain I could get from more workers, as data loading is currently my bottleneck.

Does anyone know a way to work around this? Thanks

I don’t think that moving the data to the GPU right after loading it is a good idea.

An easy workaround would be to load the data as CPU tensors in __getitem__ and let the DataLoader fetch a batch of this data. Then you can move the batch to the GPU in the main process and apply the processing steps there.
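A minimal sketch of that workaround (the CpuDataset name and the toy tensor are illustrative, and the device falls back to the CPU when no GPU is present): __getitem__ returns plain CPU tensors, so the forked workers never initialize CUDA, and the batch is moved to the device once per iteration in the main process.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class CpuDataset(Dataset):
    """Returns CPU tensors; the GPU transfer happens in the main process."""
    def __init__(self):
        self.tensor_on_ram = torch.Tensor([1, 2, 3])

    def __len__(self):
        return len(self.tensor_on_ram)

    def __getitem__(self, index):
        # No .to(device) here, so the workers never touch CUDA
        return self.tensor_on_ram[index]

if __name__ == "__main__":
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    dl = DataLoader(CpuDataset(), batch_size=1, num_workers=2)
    for batch in dl:
        batch = batch.to(device)  # transfer + image processing on the GPU here
```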

Another workaround would be to pass a custom collate_fn to the DataLoader.
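A sketch of the collate_fn approach, assuming default_collate from torch.utils.data.dataloader (the helper the DataLoader applies by default). One caveat worth noting: with num_workers > 0 the collate function also runs inside the workers, so on its own this does not avoid the fork/CUDA problem; it fits best with num_workers=0 or combined with the 'spawn' start method.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

class CpuDataset(Dataset):
    """Illustrative dataset that returns CPU tensors."""
    def __init__(self):
        self.tensor_on_ram = torch.Tensor([1, 2, 3])

    def __len__(self):
        return len(self.tensor_on_ram)

    def __getitem__(self, index):
        return self.tensor_on_ram[index]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def collate_to_device(batch):
    # Stack the CPU samples as usual, then move the whole batch at once
    return default_collate(batch).to(device)

# num_workers=0 here: with workers, collate_fn would run in the worker
# process and hit the same CUDA re-initialization error under fork
dl = DataLoader(CpuDataset(), batch_size=3, collate_fn=collate_to_device)
```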

It’s better to use pin_memory in the DataLoader, which puts batches into page-locked (pinned) memory. Then move the tensors to CUDA in your main process.
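A hedged sketch of the pin_memory variant (pinning is gated on CUDA availability here so the snippet also runs on CPU-only machines): pinned batches allow an asynchronous host-to-device copy via non_blocking=True.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class CpuDataset(Dataset):
    """Illustrative dataset that returns CPU tensors."""
    def __init__(self):
        self.tensor_on_ram = torch.Tensor([1, 2, 3])

    def __len__(self):
        return len(self.tensor_on_ram)

    def __getitem__(self, index):
        return self.tensor_on_ram[index]

if __name__ == "__main__":
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # pin_memory=True places each batch in page-locked RAM, enabling an
    # asynchronous copy to the GPU with non_blocking=True
    dl = DataLoader(CpuDataset(), batch_size=1, num_workers=2,
                    pin_memory=torch.cuda.is_available())
    for batch in dl:
        batch = batch.to(device, non_blocking=True)
```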

And when you use the spawn start method, please wrap your script’s entry point in an if __name__ == "__main__": guard.
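Putting the spawn advice together, a sketch under these assumptions: the DataLoader’s multiprocessing_context argument is used instead of a global set_start_method call, and the device falls back to the CPU when no GPU is present. With 'spawn', each worker starts a fresh interpreter, so it can initialize its own CUDA context, though each context costs startup time and GPU memory.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class CudaDataset(Dataset):
    """Dataset from the question: moves each item to the device itself."""
    def __init__(self, device):
        self.tensor_on_ram = torch.Tensor([1, 2, 3])
        self.device = device

    def __len__(self):
        return len(self.tensor_on_ram)

    def __getitem__(self, index):
        return self.tensor_on_ram[index].to(self.device)

def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    ds = CudaDataset(device)
    # 'spawn' workers are fresh interpreters, so each may initialize CUDA;
    # the __main__ guard below is required for spawn to re-import safely
    dl = DataLoader(ds, batch_size=1, num_workers=2,
                    multiprocessing_context='spawn')
    for batch in dl:
        pass

if __name__ == "__main__":
    main()
```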