Device error with dataloaders

Torch version 2.1.0.dev20230702+cu121

import torch
from torch.utils import data

torch.set_default_device('cuda')

class NullDataset(data.Dataset):
    def __len__(self) -> int:
        return 100

dataloader = data.DataLoader(NullDataset(), batch_size=64, shuffle=True, generator=torch.Generator(device='cuda'))

for data in dataloader:
    print(data)

gives errors at for data in dataloader: with

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "blabla\venv\Lib\site-packages\torch\utils\data\dataloader.py", line 633, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "blabla\venv\Lib\site-packages\torch\utils\data\dataloader.py", line 676, in _next_data
    index = self._next_index()  # may raise StopIteration
            ^^^^^^^^^^^^^^^^^^
  File "blabla\venv\Lib\site-packages\torch\utils\data\dataloader.py", line 623, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "blabla\venv\Lib\site-packages\torch\utils\data\sampler.py", line 289, in __iter__
    for idx in self.sampler:
  File "blabla\venv\Lib\site-packages\torch\utils\data\sampler.py", line 167, in __iter__
    yield from map(int, torch.randperm(n, generator=generator).numpy())
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "blabla\venv\Lib\site-packages\torch\utils\_device.py", line 76, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

I deliberately omitted the definition of __getitem__ for NullDataSet to show that it isn’t the cause of the error. How should I fix this? This error seems to be introduced by commit Do not materialize entire randperm in RandomSampler (#103339), calling .numpy() on a cuda tensor. Previously there was torch.randperm(n, generator=generator).tolist() instead of map(int, torch.randperm(n, generator=generator).numpy())

1 Like