DataLoader returns CPU tensors

Hello, I’m loading the MNIST dataset using torchvision.

import torch
from torchvision import datasets, transforms

test_iter = torch.utils.data.DataLoader(
    datasets.MNIST('.data/mnist', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True)

This code is returning tensors on CPU. Is there a way to load the batches on the GPU from the DataLoader? For instance, in torchtext there is a device option in the loaders’ signature.

You could create your own Dataset by deriving from datasets.MNIST.
However, since the data and targets will then be on the GPU, you cannot apply the torchvision.transforms, because they operate on PIL.Images.
So we would have to perform the normalization manually.
Also, I don’t know whether there is a workaround for using num_workers > 0; it seems CUDA would be initialized again in the worker processes, yielding an error.

Considering all these issues, here is a small code example:

from torch.utils.data import DataLoader
from torchvision import datasets

class MyMNIST(datasets.MNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, device='cpu'):
        # transform / target_transform are accepted for signature compatibility but not used,
        # since the PIL-based transforms cannot run on GPU tensors.
        super(MyMNIST, self).__init__(root, train=train, transform=None, target_transform=None, download=download)
        # Recent torchvision releases store either split in self.data / self.targets;
        # older releases used train_data / train_labels and test_data / test_labels.
        self.data = self.data.to(device)
        self.targets = self.targets.to(device)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]

        # You cannot perform the PIL transformations, since the data is already on GPU
        img = img.float()
        img = img / 255.
        img = img - 0.1307
        img = img / 0.3081

        return img, target

dataset = MyMNIST(root='YOUR_PATH', train=True, device='cuda:0')
loader = DataLoader(dataset, batch_size=2, num_workers=0, shuffle=True, pin_memory=False)
loader_iter = iter(loader)
data, target = next(loader_iter)
print(data.type())
> torch.cuda.FloatTensor
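
If keeping the usual transforms and num_workers > 0 matters more than having the whole dataset resident on the GPU, a different and commonly used approach is to leave the DataLoader on the CPU and move each batch to the device in the main process. Below is a minimal sketch of that idea, not the approach from the class above. The helper name batches_on_device and the random tensors standing in for MNIST are assumptions made for illustration, and the device falls back to the CPU when no GPU is available.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def batches_on_device(loader, device):
    """Yield (data, target) batches from `loader`, moved to `device`."""
    for data, target in loader:
        # non_blocking=True can only overlap the copy with compute when the
        # source batch sits in pinned memory (pin_memory=True on the loader).
        yield (data.to(device, non_blocking=True),
               target.to(device, non_blocking=True))


# Fall back to the CPU so the sketch also runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Synthetic stand-in for a normalized MNIST split (assumption: 8 random images).
images = torch.randn(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))

cpu_loader = DataLoader(TensorDataset(images, labels), batch_size=2, shuffle=True,
                        num_workers=0,  # may be raised; workers only handle CPU tensors
                        pin_memory=torch.cuda.is_available())

data, target = next(batches_on_device(cpu_loader, device))
print(data.device, data.shape)
```

With this pattern the worker processes only ever handle CPU tensors, so CUDA is touched only in the main process; the cost is one host-to-device copy per batch.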