Expected a 'mps:0' generator device but found 'mps'

I'm trying to load custom data for a CNN using the MPS backend on a MacBook Pro (M3 Pro), but I hit an error where the generator is expected to be on 'mps:0' but is found on 'mps'.

  • Python ver: 3.11.6
  • PyTorch ver: 2.1.1
  • torchvision ver: 0.16.1
  • Environment: Jupyter Notebook (on VSCode)

Code:

import torch
from torchvision import datasets, models

if torch.backends.mps.is_available():
    mps_device = torch.device('mps:0')
    torch.set_default_device(mps_device)
else:
    print('MPS device not found.')

trainset = datasets.ImageFolder(root=train_dir, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=mps_device))

model = models.resnet18(weights='IMAGENET1K_V1')

Error stack:

File /opt/homebrew/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

     75 if func in _device_constructors() and kwargs.get('device') is None:
     76     kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)

RuntimeError: Expected a 'mps:0' generator device but found 'mps'

Try using the mps device explicitly (for example, moving the model and each batch with .to('mps')) instead of calling torch.set_default_device.
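
A minimal sketch of that suggestion, assuming it means picking the device by hand and leaving the DataLoader's generator on the CPU. The TensorDataset and the small linear model are hypothetical stand-ins for the ImageFolder dataset and resnet18 from the question, chosen only so the snippet runs without image files or a weights download; it falls back to the CPU when MPS is not available.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Pick the device explicitly; torch.set_default_device() is never called.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Stand-in for datasets.ImageFolder(...): 8 fake 3x32x32 images with labels 0 or 1.
images = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 2, (8,))
trainset = TensorDataset(images, labels)

# The shuffling generator stays on the CPU, which is torch.Generator's default.
generator = torch.Generator().manual_seed(0)
trainloader = DataLoader(trainset, batch_size=4, shuffle=True, generator=generator)

# Stand-in for models.resnet18(...), moved to the chosen device explicitly.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 2)).to(device)

batch_shapes = []
for inputs, targets in trainloader:
    # Batches come out of the loader on the CPU; move each one by hand.
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs)
    batch_shapes.append(tuple(outputs.shape))
```

With this layout only the model and the per-batch tensors live on the MPS device, so the sampler inside the DataLoader never has to reconcile an 'mps:0' default device with an 'mps' generator.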
