Is there a .cuda() equivalent for MPS?

I don’t see one, so yes, you would need to add .to() calls or make sure your tensors are created directly on an MPS device.
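A minimal sketch of both options. It assumes a PyTorch build with the MPS backend (Apple silicon) and falls back to CPU when `torch.backends.mps.is_available()` is false, so the snippet still runs elsewhere. The model shape and variable names are purely illustrative:

```python
import torch
import torch.nn as nn

# Use MPS when the backend is available, otherwise fall back to CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Option 1: move existing modules/tensors with .to(), the general stand-in for .cuda()
model = nn.Linear(4, 2).to(device)   # nn.Module.to moves parameters in place and returns the module
x_cpu = torch.randn(8, 4)
x = x_cpu.to(device)                 # Tensor.to returns a new tensor on the target device

# Option 2: instantiate tensors on the device directly via the factory functions' device= argument
y = torch.zeros(8, 2, device=device)

# All operands now live on the same device, so the ops run there
out = model(x) + y
print(out.device)
```

Option 2 avoids allocating on the CPU first and then copying, which is why creating tensors on the target device is usually preferable when you control the construction site.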

Alternatively, something I’ve been using quite a bit is the global setting torch.set_default_device (see the PyTorch 2.0 documentation).
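A rough sketch of how that looks, assuming PyTorch 2.0 or later (where `torch.set_default_device` was introduced) and again falling back to CPU when MPS isn't available. Factory calls without an explicit `device=` argument, including the ones `nn.Linear` uses to allocate its parameters, then land on the default device. Tensors that already exist are not moved:

```python
import torch
import torch.nn as nn

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Make `device` the default for tensor factory functions (torch.randn, torch.zeros, ...)
torch.set_default_device(device)

model = nn.Linear(4, 2)   # parameters are allocated on the default device
x = torch.randn(8, 4)     # no device= or .to() needed
out = model(x)
print(out.device)

# Restore the usual CPU default so unrelated code is unaffected
torch.set_default_device("cpu")
```

Because the setting is global, resetting it (or scoping it to your training entry point) helps keep other code from unexpectedly allocating on MPS.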