err: Expected a 'mps:0' generator device but found 'cpu'

I ran the following code to split my data:

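import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder

# Specify the device to use
mps_device = torch.device("mps:0")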

dataset = ImageFolder(root=dataset_path, transform=data_transforms)
# Define the percentage for each split
train_ratio = 0.8
test_ratio = 0.2

total_size = len(dataset)
train_size = int(train_ratio * total_size)
test_size = total_size - train_size

train_set, test_set = random_split(dataset, [train_size, test_size])  # error on this line

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

error trace:
Traceback (most recent call last):
File "/Users/vietpham1023/Desktop/python-resource-yoga-pose/", line 85, in
train_set, test_set = random_split(dataset, [train_size, test_size])
File "/Users/vietpham1023/anaconda3/lib/python3.11/site-packages/torch/utils/data/", line 420, in random_split
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload]
File "/Users/vietpham1023/anaconda3/lib/python3.11/site-packages/torch/utils/", line 77, in __torch_function__
return func(*args, **kwargs)
RuntimeError: Expected a 'mps:0' generator device but found 'cpu'
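The last frame appears to be PyTorch's default-device hook (`torch.utils._device.DeviceContext.__torch_function__`), which is only active after `torch.set_default_device(...)` or inside a `with torch.device(...)` block. It adds `device='mps:0'` to the `randperm` call, while `random_split` falls back to the CPU default generator, hence the mismatch. One possible workaround, sketched under the assumption of PyTorch 2.x (where `torch.device` works as a context manager), is to run just the split under a temporary CPU device context with an explicit, seeded CPU generator. `split_on_cpu` is a hypothetical helper name, not a PyTorch API:

```python
import torch
from torch.utils.data import Dataset, Subset, TensorDataset, random_split


def split_on_cpu(dataset: Dataset, lengths: list[int], seed: int = 0) -> list[Subset]:
    """Hypothetical helper: split `dataset` on the CPU with a seeded CPU generator.

    Inside `with torch.device("cpu")`, the randperm call made by random_split
    is placed on the CPU, which matches the CPU generator passed in.
    """
    generator = torch.Generator(device="cpu").manual_seed(seed)
    with torch.device("cpu"):
        return random_split(dataset, lengths, generator=generator)


# Usage with a toy dataset of 10 items
toy = TensorDataset(torch.arange(10))
train_set, test_set = split_on_cpu(toy, [8, 2], seed=0)
```

This only covers the split itself. If a global default device stays set, a `shuffle=True` DataLoader may hit the same mismatch when its sampler draws a permutation, so not setting a default device at all is the more robust option.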

I searched Google but had no luck.
I also tried overriding the generator at the line shown in the stack trace:

# Set the generator device to 'mps:0'
generator = torch.Generator(device='mps:0')

# Use the generator in the randperm function
indices = randperm(sum(lengths), generator=generator).tolist() 

but got this error:
RuntimeError: Expected a 'mps:0' generator device but found 'mps'

Continuing to debug, I traced it to the last function in the stack trace:

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device  # I replaced this line with kwargs['device'] = 'mps:0', but no luck
        return func(*args, **kwargs)
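The hook only fills in `device` when the caller did not pass one (the `kwargs.get('device') is None` check), and `randperm` is one of the factory functions it covers, while the `generator` argument is passed through untouched. A CPU-only sketch of that rule, using the `meta` device as a stand-in for `mps:0` so it runs on any machine (`devices_under_context` is a hypothetical helper):

```python
import torch


def devices_under_context(ctx: str = "meta") -> tuple[str, str]:
    """Hypothetical helper: show where factory calls land inside a device context."""
    with torch.device(ctx):
        injected = torch.zeros(3)                # no device= given, so the context fills it in
        explicit = torch.zeros(3, device="cpu")  # an explicit device= is left untouched
    return injected.device.type, explicit.device.type


print(devices_under_context())  # ('meta', 'cpu')
```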

While debugging, I also saw that at the line kwargs['device'] = self.device, self.device is 'mps:0',
but the 'generator' entry in kwargs has device just 'mps'. I tried to reassign kwargs['generator'].device = self.device, but got an error saying I can't write to it.
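That fits how `torch.Generator` behaves: its `device` is fixed when the generator is constructed and is exposed read-only, so the only way to get a generator on another device is to create a new one. A quick CPU check (the helper name is hypothetical):

```python
import torch


def generator_device_is_writable() -> bool:
    """Hypothetical helper: try to overwrite .device on a fresh CPU generator."""
    gen = torch.Generator()  # constructed on the CPU by default
    try:
        gen.device = torch.device("cpu")
    except AttributeError:
        # .device has no setter, so assignment raises instead of moving the generator
        return False
    return True


# Choose the device at construction time instead of mutating it afterwards.
fresh = torch.Generator(device="cpu").manual_seed(42)
```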

Thanks, I appreciate any help.

I have similar issues. I will post if I find a solution.

Without using torch.set_default_device, the code below loads the data onto the CPU, initializes the model, and transfers the model to the GPU for processing:
mps_device = torch.device("mps:0")
model = ConvNet(num_classes).to(mps_device)  # Move the model's parameters to the GPU
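If the same script also has to run on machines without MPS, one option is to pick the device at runtime instead of hard-coding `mps:0`. A minimal sketch; `pick_device` is a hypothetical helper, while `torch.backends.mps.is_available()` is PyTorch's standard availability check:

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: use Apple's MPS backend when available, else the CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


mps_device = pick_device()
```

The rest of the code then uses `mps_device` unchanged.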
In the training loop, move each batch to the GPU as you iterate over train_loader; this avoids the error:
# Training loop
for epoch in range(num_epochs):
    for step, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()  # Zero the gradient buffers
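        inputs, labels = inputs.to(mps_device), labels.to(mps_device)  # Move this batch to the GPU
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # Backpropagate the loss
        optimizer.step()  # Update the model's weights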
        print(f'Epoch [{epoch + 1}/{num_epochs}] - Step [{step + 1}/{len(train_loader)}] - Loss: {loss.item()}')
    print(f'Epoch [{epoch + 1}/{num_epochs}] - Loss: {loss.item()}')

So it is still possible to move data to the GPU and benefit from Apple's MPS acceleration. However, copying each batch from the CPU may be slightly slower than creating the data directly on the GPU. I suspect the library has not been updated to handle this case, so anyone who hits a similar issue can use this workaround.


You can check my solution, bro.


@Viet_Pham LoL, I've been doing something similar to your solution, but my data is tokenized text that depends on torchtext, and it blows up somewhere deep inside the API code on the first call to the equivalent of your enumerate(train_loader), with the message:

"RuntimeError: Expected a 'mps:0' generator device but found 'mps'" (really, torchtext?)

This is the generator it’s complaining about:
generator = torch.Generator(device='mps:0')
Likewise, every tensor and module that can be put on ‘mps:0’ is mapped, and the first step at the entry point of my code is the usually-reliable:
(and yes, I've tried every permutation of specifying the device as 'mps' and/or passing a device object rather than a str)

FYI, on everything else I've tried, my high-end M2U 'cpu' is about 4x-5x faster than a T4 GPU, and 'mps:0' is about 3x-4x faster than 'cpu'.
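Speedups like these depend heavily on the workload. A rough way to check them on your own machine is to time the same operation on each device, synchronizing before reading the clock because MPS and CUDA kernels run asynchronously. A sketch only: `time_matmul` is a hypothetical helper, the matrix size is arbitrary, and a single matmul is not a full model benchmark:

```python
import time

import torch


def _sync(device: torch.device) -> None:
    # MPS and CUDA kernels run asynchronously; wait for them before reading the clock.
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize()


def time_matmul(device: torch.device, n: int = 1024, repeats: int = 10) -> float:
    """Hypothetical helper: mean seconds per n x n float32 matmul on `device`."""
    a = torch.randn(n, n, device=device)
    b = torch.randn(n, n, device=device)
    _ = a @ b  # warm-up run, excluded from the timing
    _sync(device)
    start = time.perf_counter()
    for _ in range(repeats):
        _ = a @ b
    _sync(device)
    return (time.perf_counter() - start) / repeats


for name in ["cpu"] + (["mps"] if torch.backends.mps.is_available() else []):
    print(name, time_matmul(torch.device(name)))
```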

In my case, I think the solution is to remove all traces of torchtext from my code. That may sound harsh, but check out the docs (link below): they begin with a full-page warning message, which is never a sign of reliability …