CUDA initialization error when using a DataLoader with CUDA tensors

My dataset is small, and I want to load the whole dataset into GPU memory when the dataset is created. At the same time, I still want to use torch.utils.data.DataLoader for compatibility with other situations where I load my data on the fly.

A short example that reproduces the problem is as follows.

import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch

data = np.array([[1,2,3], [4,5,6]])
# I move the dataset to the GPU first
ds = TensorDataset(torch.Tensor(data).cuda())
dl = DataLoader(ds, batch_size=1, num_workers=1, shuffle=True)
for x in dl:
    print(x)

However, it crashes.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-10-11b3bb8f6574> in <module>
      6 ds = TensorDataset(torch.Tensor(data).cuda())
      7 dl = DataLoader(ds, batch_size=1, num_workers=1, shuffle=True)
----> 8 for x in dl:
      9     print(x)

~/.conda/envs/ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    635                 self.reorder_dict[idx] = batch
    636                 continue
--> 637             return self._process_next_batch(batch)
    638 
    639     next = __next__  # Python 2 compatibility

~/.conda/envs/ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
    656         self._put_indices()
    657         if isinstance(batch, ExceptionWrapper):
--> 658             raise batch.exc_type(batch.exc_msg)
    659         return batch
    660 

RuntimeError: Traceback (most recent call last):
  File "/home/swyoon/.conda/envs/ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/swyoon/.conda/envs/ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/swyoon/.conda/envs/ml/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 40, in __getitem__
    return tuple(tensor[index] for tensor in self.tensors)
  File "/home/swyoon/.conda/envs/ml/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 40, in <genexpr>
    return tuple(tensor[index] for tensor in self.tensors)
RuntimeError: CUDA error: initialization error

Is using a DataLoader not supported when all of my data is already loaded on the GPU?
Is this error intended behavior in PyTorch?


When I write a custom data loader which simply batches through the TensorDataset, everything is fine.
So I guessed the problem was the multiprocessing, and tried num_workers=0, which disables multiprocessing in the original DataLoader.
Now it works.
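
For reference, a minimal sketch of such a manual batching loop over the GPU-resident tensor (the helper name and the shuffling logic here are illustrative, not the original custom loader):

import torch

# Manual batching over a tensor that already lives on the GPU, avoiding DataLoader workers.
def iterate_gpu_batches(tensor, batch_size, shuffle=True):
    n = tensor.size(0)
    # Generate the (optionally shuffled) index order on the same device as the data
    idx = torch.randperm(n, device=tensor.device) if shuffle else torch.arange(n, device=tensor.device)
    for start in range(0, n, batch_size):
        yield tensor[idx[start:start + batch_size]]

data_gpu = torch.Tensor([[1, 2, 3], [4, 5, 6]]).cuda()
for x in iterate_gpu_batches(data_gpu, batch_size=1):
    print(x)

The equivalent fix with the built-in loader is simply passing num_workers=0 so that no worker processes are forked.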


If you would like to use multiple workers in your DataLoader, pass the data tensor as a CPU tensor to TensorDataset and push each batch to the GPU using:

ds = TensorDataset(torch.from_numpy(data))
dl = DataLoader(ds, batch_size=1, num_workers=1, shuffle=True)
for (x,) in dl:
    x = x.to('cuda', non_blocking=True)

Otherwise CUDA has to be re-initialized in each forked worker process, which yields the error you are seeing.


@ptrblck Thanks for the reply.

  1. Which way would you recommend in terms of performance? The reason I'm putting all my data on the GPU first is to increase speed.

  2. When using non_blocking=True, is it okay not to use pin_memory=True in the DataLoader? The torch.Tensor.cuda() docs say non_blocking is only effective when the data is in pinned memory.

  1. If your data is already on the GPU, you won't really need multiple workers, since most likely you are just slicing the data and passing it to the model (or passing all the data at once). However, this approach uses some of your limited GPU memory, which you might otherwise need for, e.g., a bigger model.

  2. The copy will only be non-blocking if pin_memory=True was set, so you should use it. I missed that part in my code snippet, so thanks for pointing it out. :wink: (See the updated sketch below.)
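
Putting both points together, the earlier snippet would look roughly like this (a sketch, not the original answer; pin_memory=True is the only change):

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

data = np.array([[1, 2, 3], [4, 5, 6]])
# Keep the dataset on the CPU; let the DataLoader pin each batch in page-locked memory
ds = TensorDataset(torch.from_numpy(data))
dl = DataLoader(ds, batch_size=1, num_workers=1, shuffle=True, pin_memory=True)
for (x,) in dl:
    # Asynchronous host-to-device copy; non_blocking only takes effect because the batch is pinned
    x = x.to('cuda', non_blocking=True)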


@ptrblck Now I think everything is clear. Thank you very much!

I'm running into exactly the same issue. I have a working version that loads the data into GPU memory without CUDA calls (or at least without triggering this issue) via pyarrow.read_parquet, and it's 2.3x faster with 4 workers than with 0, despite all of the data being in pinned GPU memory.

Following the instructions in the multiprocessing best practices, I've tried setting torch's multiprocessing start method to spawn (I've also tried forkserver), but when I do this I run into invalid device pointer: 0x7f8a8c000000 at /pytorch/aten/src/THC/THCCachingAllocator.cpp:301 exceptions.
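
For reference, the start-method change described above looks roughly like this (a sketch of the attempt; as noted, it did not resolve the invalid device pointer error here):

import torch.multiprocessing as mp

if __name__ == '__main__':
    # Must be called once, before any DataLoader workers or other processes are created
    mp.set_start_method('spawn')  # 'forkserver' was also tried
    # ... build the dataset / DataLoader and train as usual ...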

If the data is already on the GPU, I'd expect 0 workers to be fast, but there is a big performance hit, so much so that it doesn't seem worth doing from what I can see. I'm relatively new to multiprocessing, so there's a good chance I'm doing something wrong, and it may just make sense to stick with CPU data loading to keep the GPU memory available for the model, but I'd love to figure out a more robust solution, particularly in light of on-GPU preprocessing options like RAPIDS.


It doesn't seem to me that the data will be pre-loaded to the GPU. If the next operation depends on the data, this doesn't really gain any performance, and if we're moving the data back to the CPU, the data loading isn't overlapped either. Is there any way to pre-load the data in the main CUDA context?

The easiest way to preload all data onto the GPU is to simply copy it there (Tensor.cuda()) and maintain a Python list with all the samples you want to process. Then, instead of iterating over a dataset, you iterate over a Python list of pre-existing CUDA tensors. The reasons these multiprocessing data loaders exist are that 1) datasets are typically much larger than a single GPU can hold resident in memory, and 2) a single CPU cannot preprocess enough examples per second to saturate GPU throughput.
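
A minimal sketch of that approach (the data, batch size, and model call are placeholders):

import torch

data = torch.randn(1000, 3)                     # placeholder dataset, small enough to fit on the GPU
samples = [row.cuda() for row in data]          # copy every sample to the GPU once, up front

batch_size = 32
for start in range(0, len(samples), batch_size):
    batch = torch.stack(samples[start:start + batch_size])  # stacking happens on the GPU
    # output = model(batch)  # the batch is already on the device, no per-step copy needed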
