Is there a .cuda() equivalent for MPS? For example, what is the MPS counterpart of X_test.cuda()? Is the transfer automatic if the device is “mps”?
I don’t see one, so yes, you would need to add .to("mps") calls or make sure your tensors are instantiated on an MPS device.
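Here is a minimal sketch of both options, assuming a PyTorch build with MPS support (1.12 or later). The CPU fallback and the tensor shapes are my own additions so the snippet runs on any machine; X_test just mirrors the name from the question.

```python
import torch

# Use MPS when it is available; otherwise fall back to CPU (assumption for portability).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Option 1: create on the CPU, then move explicitly (the counterpart of X_test.cuda()).
X_test = torch.randn(4, 3)
X_test = X_test.to(device)

# Option 2: instantiate the tensor on the target device directly.
y_test = torch.zeros(4, device=device)

print(X_test.device, y_test.device)
```

Note that .to() returns a new tensor rather than moving the original in place, so the result has to be assigned back.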
Alternatively, something I’ve been using quite a bit is the global flag torch.set_default_device (see the PyTorch 2.0 documentation).
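A rough sketch of how that flag can be used, assuming PyTorch 2.0 or later. The CPU fallback, the shapes, and the small Linear layer are only there for illustration.

```python
import torch

# Fall back to CPU when MPS is not available (assumption for portability).
device = "mps" if torch.backends.mps.is_available() else "cpu"

# From here on, factory functions allocate on this device unless told otherwise.
torch.set_default_device(device)

X_test = torch.randn(4, 3)        # no .to() call needed
model = torch.nn.Linear(3, 2)     # parameters are created on the default device too

# An explicit device argument still takes precedence over the default.
on_cpu = torch.zeros(2, device="cpu")

print(X_test.device, model.weight.device, on_cpu.device)
```

The flag only changes where new tensors are allocated; tensors that already exist stay where they are and still need .to().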