Is there a .cuda() equivalent for MPS?

Is there a .cuda() equivalent for MPS? For example, what is the MPS counterpart of X_test.cuda()? Or does it happen automatically if the device is "mps"?

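I don't see one, so no, it isn't automatic: you would need to add .to() calls (e.g. X_test.to("mps")) or make sure your tensors are created on an MPS device in the first place (e.g. by passing device="mps" to the factory function).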
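Here's a minimal sketch of both options, assuming PyTorch 1.12 or newer (where torch.backends.mps was introduced). X_test, weights and the small Linear model are just placeholders for illustration, and the code falls back to CPU so it also runs on non-Apple machines:

```python
import torch

# Use MPS when the backend is available, otherwise fall back to CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Option 1: move an existing tensor, the MPS counterpart of X_test.cuda()
X_test = torch.randn(4, 3)  # placeholder data, created on the CPU
X_test = X_test.to(device)  # returns a copy on `device`; reassign it

# Option 2: create the tensor directly on the target device
weights = torch.zeros(3, 2, device=device)

# Modules are moved the same way; Module.to() moves parameters in place
model = torch.nn.Linear(3, 2).to(device)
out = model(X_test)
print(out.device)
```

Note that Tensor.to() does not modify the tensor in place, so forgetting the reassignment leaves X_test on the CPU.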

Alternatively, something I've been using quite a bit is the global setting torch.set_default_device (see the PyTorch 2.0 documentation), which changes the device that tensors are created on by default.
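A rough sketch of the global-default approach, assuming PyTorch 2.0 or newer (torch.set_default_device does not exist in earlier releases). Again, X_test and the Linear model are placeholders. Factory calls without an explicit device, including the ones nn.Linear uses to allocate its parameters, pick up the default, while an explicit device argument still takes precedence:

```python
import torch

# Use MPS when the backend is available, otherwise fall back to CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Make `device` the default for tensor factory functions from here on
torch.set_default_device(device)

X_test = torch.randn(4, 3)     # allocated on `device`, no .to() needed
model = torch.nn.Linear(3, 2)  # parameters are also allocated on `device`
out = model(X_test)

# An explicit device argument overrides the global default
cpu_tensor = torch.ones(2, device="cpu")

print(X_test.device, out.device, cpu_tensor.device)

# Restore the usual default so later code is not affected
torch.set_default_device("cpu")
```

Since this is process-wide state, resetting it (as at the end above) is worth doing in shared code or tests.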