Move numpy dataset to GPU memory

I’m looking to move my dataset to GPU memory (it’s fairly small and should fit). I thought something like this would work, but I end up with a CUDA error: initialization error:

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataSet(Dataset):
    def __init__(self, X, y, device='cpu'):
        '''
        So that we can move the entire dataset to the GPU.
        :param X: scaled float32 numpy array of features
        :param y: scaled float32 numpy vector of targets
        :param device: 'cpu' or 'cuda:0'
        '''
        self.X = torch.from_numpy(X).to(device)
        # y needs to be a column vector (or at least it did with the
        # regular dataset).
        self.y = torch.from_numpy(y[:, None]).to(device)

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, item):
        return self.X[item], self.y[item]

Then I use it to initialize a training set:

# device = 'cuda:0'; X_train and y_train are float32 numpy arrays
train = MyDataSet(X_train, y_train, device=device)
loader = DataLoader(train, batch_size=256, shuffle=True, num_workers=1)

I initialize my model and also send it to the GPU.
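For context, my model setup looks roughly like this; the model class, optimizer, and loss below are placeholders rather than the exact code:

# Placeholder setup: MyModel, Adam, and MSELoss stand in for my actual choices.
model = MyModel().to(device)   # move the parameters to the GPU
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()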

I start my training loop and it dies on the ‘enumerate(loader, 0)’ call:

for epoch in range(300):
    running_loss = 0.0
    for i, data in enumerate(loader, 0):  # <<--- This is where the error is thrown

        inputs, labels = data[0].to(device), data[1].to(device)
        #inputs, labels = data  # <-- in theory if data is all on GPU I shouldn't need to 'move' it there again.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

Traceback (most recent call last):
  File "nn.py", line 116, in <module>
    for i, data in enumerate(loader, 0):
  File "/torch/utils/data/dataloader.py", line 819, in __next__
    return self._process_data(data)
  File "/torch/utils/data/dataloader.py", line 846, in _process_data
    data.reraise()
  File "/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.

Original Traceback (most recent call last):
  File "torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/MyDataSet.py", line 21, in __getitem__
    return self.X[item], self.y[item]
RuntimeError: CUDA error: initialization error

The second part of the traceback suggests the problem is in the __getitem__ of my custom Dataset.

When I force my custom Dataset to be CPU-only, everything works as expected (and behaves the same as if I used PyTorch’s normal Dataset class). I can also use the CPU-based Dataset and push each mini-batch to the GPU inside my training loop, and that works too, but it is actually slower than just doing everything on the CPU because of the transfer overhead.
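For reference, the per-batch transfer variant I mean looks roughly like this; the pin_memory and non_blocking flags are just an attempt to hide some of the copy cost, not something I have tuned:

# CPU-resident dataset; each mini-batch is copied to the GPU inside the loop.
train_cpu = MyDataSet(X_train, y_train, device='cpu')
loader = DataLoader(train_cpu, batch_size=256, shuffle=True,
                    num_workers=2, pin_memory=True)

for i, data in enumerate(loader, 0):
    # non_blocking only helps when the host tensors are pinned
    inputs = data[0].to(device, non_blocking=True)
    labels = data[1].to(device, non_blocking=True)
    # ... rest of the training step as above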

When I do:

self.X = torch.from_numpy(X).to(device)
print(self.X)

It shows me that the data appears to have moved to the GPU:

---SNIP--
-3.3333e-01,  1.5024e+04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0')

I feel like I need to put some sort of ‘cuda synchronize’ somewhere (probably in the data loader) so that the enumerate waits until the dataset is transferred. Not sure what that might be.

Thoughts?

Weirdly, setting ‘num_workers=0’ seems to allow it to work. ??

# this works and doesn't throw the errors from above
loader = DataLoader(train, batch_size=256, shuffle=True, num_workers=0)

While that seems to fix the error, I’m wondering if there is a way to structure my code for better performance on the GPU. With num_workers=0, performance is still limited: only one core is used at 100% to go over the epoch loop, and a small batch_size is much slower than a large batch_size. I thought that by moving the data to the GPU I wouldn’t be paying for memory traffic (I’m assuming it’s memory traffic), or are some parts of the training loop necessarily on the CPU and need to pull results from the GPU?

When using ‘cpu’ only, all my cores are used. I’m assuming some of that is the numpy backend plus the Python interpreter.

Is there something else I’m doing wrong that is limiting the performance of ‘on-GPU’ datasets (some setting or issue with my epoch loop or DataLoader, for example)?
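One alternative I’ve been sketching (not sure if it’s the intended pattern) is to skip the DataLoader entirely once the tensors already live on the GPU, and batch with a shuffled index tensor each epoch:

# Rough sketch: manual batching over GPU-resident tensors, no DataLoader.
X_gpu = torch.from_numpy(X_train).to(device)
y_gpu = torch.from_numpy(y_train[:, None]).to(device)

batch_size = 256
n = X_gpu.size(0)

for epoch in range(300):
    perm = torch.randperm(n, device=device)   # shuffle indices on the GPU
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        inputs, labels = X_gpu[idx], y_gpu[idx]
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()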

Each worker in your DataLoader will try to create its own CUDA context, since you are using CUDA tensors, which will raise this error.
You could use the 'spawn' start method for multiprocessing as described here.

Here is a small example:

import torch
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, device='cpu'):
        super(MyDataset, self).__init__()
        self.data = torch.randn(100, 1, device=device)
        
    def __getitem__(self, index):
        x = self.data[index]
        return x
    
    def __len__(self):
        return len(self.data)
    
def main():
    dataset = MyDataset(device='cuda')
    loader = DataLoader(
        dataset,
        num_workers=2
    )
    
    for data in loader:
        print(data.device)
    
if __name__=='__main__':
    mp.set_start_method('spawn')
    main()