I tried to split my dataset into training and test sets with the following code:
```python
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder

# dataset_path, data_transforms and batch_size are defined elsewhere (not shown)

# Specify the device to use
mps_device = torch.device("mps:0")
torch.set_default_device(mps_device)

dataset = ImageFolder(root=dataset_path, transform=data_transforms)

# Define the percentage for each split
train_ratio = 0.8
test_ratio = 0.2
total_size = len(dataset)
train_size = int(train_ratio * total_size)
test_size = total_size - train_size
train_set, test_set = random_split(dataset, [train_size, test_size])  # error on this line

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)
```
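To make the failing step concrete, here is a plain-Python model of what the split does with these sizes: compute the lengths from the ratio, draw one random permutation of all indices (the `randperm(sum(lengths), generator=generator)` call in the traceback), then slice that permutation into consecutive chunks. This is a simplified sketch, not PyTorch's implementation: `split_indices` is a hypothetical helper, and Python's `random.Random` stands in for a `torch.Generator`.

```python
import random
from itertools import accumulate
from typing import List, Sequence


def split_indices(total_size: int, ratios: Sequence[float], seed: int = 0) -> List[List[int]]:
    """Hypothetical helper: a plain-Python model of random_split's index logic.

    Not PyTorch code. random.Random stands in for torch.Generator and
    rng.shuffle stands in for torch.randperm.
    """
    # Same sizing as the question: int() for the leading splits, remainder to the last
    # (so the last ratio is implied, just like test_ratio in the original code)
    lengths = [int(r * total_size) for r in ratios[:-1]]
    lengths.append(total_size - sum(lengths))

    # One random permutation of all indices (the step that fails on MPS)
    rng = random.Random(seed)
    perm = list(range(total_size))
    rng.shuffle(perm)

    # Slice the permutation into consecutive chunks of the requested lengths
    ends = list(accumulate(lengths))
    starts = [end - length for end, length in zip(ends, lengths)]
    return [perm[s:e] for s, e in zip(starts, ends)]


train_idx, test_idx = split_indices(total_size=10, ratios=[0.8, 0.2])
print(len(train_idx), len(test_idx))  # 8 2
```

The `.tolist()` at the end of the traced line shows that PyTorch also turns the permutation straight into a Python list, so the split itself is only index arithmetic.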
Error trace:
```
Traceback (most recent call last):
  File "/Users/vietpham1023/Desktop/python-resource-yoga-pose/convolutional_neural_net.py", line 85, in
    train_set, test_set = random_split(dataset, [train_size, test_size])
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vietpham1023/anaconda3/lib/python3.11/site-packages/torch/utils/data/dataset.py", line 420, in random_split
    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vietpham1023/anaconda3/lib/python3.11/site-packages/torch/utils/_device.py", line 77, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected a 'mps:0' generator device but found 'cpu'
```
I searched Google but found nothing relevant.
I also tried creating an MPS generator and passing it to the `randperm` call shown in the stack trace:
```python
# Set the generator device to 'mps:0'
generator = torch.Generator(device='mps:0')
# Use the generator in the randperm function
indices = randperm(sum(lengths), generator=generator).tolist()
```
but got this error instead:
```
RuntimeError: Expected a 'mps:0' generator device but found 'mps'
```
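The two messages differ only in the generator's device, and the second one suggests that a generator created with `device='mps:0'` reports its device as plain `'mps'`. Both read like an exact equality check between the device of the tensor `randperm` creates (`'mps:0'`, from the default device) and the generator's device, where a device is a type plus an optional index. Below is a minimal model of that reading; `Device` and `check_generator_device` are hypothetical names, not PyTorch APIs, and the comparison rule is inferred from the error text. In PyTorch itself, `torch.device('mps') == torch.device('mps:0')` is `False` for the same reason: the first has index `None`, the second index `0`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Device:
    """Hypothetical stand-in for torch.device: a device type plus an optional index."""
    kind: str
    index: Optional[int] = None

    @classmethod
    def parse(cls, spec: str) -> "Device":
        # "mps:0" -> Device("mps", 0); "mps" -> Device("mps", None)
        kind, _, idx = spec.partition(":")
        return cls(kind, int(idx) if idx else None)

    def __str__(self) -> str:
        return self.kind if self.index is None else f"{self.kind}:{self.index}"


def check_generator_device(result_device: str, generator_device: str) -> None:
    """Model of the check the error text implies: type AND index must match."""
    res, gen = Device.parse(result_device), Device.parse(generator_device)
    if res != gen:
        raise RuntimeError(f"Expected a '{res}' generator device but found '{gen}'")


check_generator_device("mps:0", "mps:0")  # identical devices: no error

messages: List[str] = []
for gen_spec in ("cpu", "mps"):  # the generator devices reported in the two errors
    try:
        check_generator_device("mps:0", gen_spec)
    except RuntimeError as err:
        messages.append(str(err))

for msg in messages:
    print(msg)
# Expected a 'mps:0' generator device but found 'cpu'
# Expected a 'mps:0' generator device but found 'mps'
```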
Continuing to debug, I traced into the last function in the stack trace, `__torch_function__` in `torch/utils/_device.py`:
```python
def __torch_function__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    if func in _device_constructors() and kwargs.get('device') is None:
        kwargs['device'] = self.device  # I replaced this line with: kwargs['device'] = 'mps:0' but no luck
    return func(*args, **kwargs)
```
While debugging, I saw that on the line `kwargs['device'] = self.device`, `self.device` is `'mps:0'`, but the generator in `kwargs['generator']` is on `'mps'`. I tried reassigning `kwargs['generator'].device = self.device`, but got an error saying the attribute can't be written to.
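The `__torch_function__` hook quoted above only fills in the `device` keyword when it is missing; it never touches `generator`. When no generator is passed, `random_split` falls back to `torch.default_generator`, which is a CPU generator, and that matches the `'cpu'` in the first error. Below is a plain-Python model of that flow, assuming the exact-match check that the error messages suggest. `randperm_model` and `DefaultDeviceModel` are hypothetical stand-ins, and generators are represented only by their device strings.

```python
from typing import Any, Callable, List, Optional

# Assumption for the model: random_split's default generator lives on the CPU
DEFAULT_GENERATOR_DEVICE = "cpu"


def randperm_model(n: int, *, device: Optional[str] = None,
                   generator: Optional[str] = None) -> List[int]:
    """Hypothetical stand-in for torch.randperm; `generator` is just its device string."""
    device = device or "cpu"
    gen_device = generator or DEFAULT_GENERATOR_DEVICE
    if gen_device != device:  # exact string match, as the error messages imply
        raise RuntimeError(
            f"Expected a '{device}' generator device but found '{gen_device}'"
        )
    return list(range(n))  # a real randperm would shuffle; order is irrelevant here


class DefaultDeviceModel:
    """Model of the __torch_function__ hook: it fills in `device` only."""

    def __init__(self, device: str) -> None:
        self.device = device

    def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        if kwargs.get("device") is None:
            kwargs["device"] = self.device  # the default device is injected here
        return func(*args, **kwargs)  # `generator` is passed through unchanged


mode = DefaultDeviceModel("mps:0")
outcomes: List[str] = []
# None: no generator (first error); "mps": the MPS generator (second error)
for gen in (None, "mps"):
    try:
        mode.call(randperm_model, 5, generator=gen)
        outcomes.append("ok")
    except RuntimeError as err:
        outcomes.append(str(err))

for line in outcomes:
    print(line)
# Expected a 'mps:0' generator device but found 'cpu'
# Expected a 'mps:0' generator device but found 'mps'
```

In this model, overwriting `kwargs['device']` with `'mps:0'` changes nothing, because `self.device` already holds that value; the mismatch sits entirely on the generator side.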
Thanks, I appreciate any help.