Hello world not working with mps

This simple PyTorch hello-world script (xor_torch.py) does not converge when run on the MPS device:

Any suggestions about the root cause?

PS: In general, PyTorch with MPS works fine here:

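import time
import torch

# Two random matrices created on the CPU, then copied to the MPS device
c = torch.rand((10000, 500)).to("mps")
d = torch.rand((500, 10000)).to("mps")

tic = time.time()
# MPS kernels are queued asynchronously, so this call may return before the product is computed
torch.matmul(c, d)
toc = time.time()
print(toc - tic)
print(c.device, "100 times faster than cpu, NICE!")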
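One caveat about the timing above: MPS, like CUDA, queues work asynchronously, so wrapping a single `torch.matmul` call in `time.time()` can stop the clock before the result has actually been computed. A fairer comparison warms up first and waits for the device before reading the clock. This is a sketch, assuming PyTorch 2.0 or later (for `torch.mps.synchronize()`); the helper names and the `n`, `k`, `repeats` parameters are illustrative and not from the original post.

```python
import time

import torch


def pick_device() -> torch.device:
    """Use MPS when this PyTorch build and machine support it, else fall back to CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def sync(device: torch.device) -> None:
    """Block until all queued work on `device` has finished (nothing to do on CPU)."""
    if device.type == "mps":
        torch.mps.synchronize()  # available in PyTorch 2.0 and later


def time_matmul(device: torch.device, n: int = 10000, k: int = 500, repeats: int = 10) -> float:
    """Average wall-clock seconds for one (n x k) @ (k x n) matmul on `device`."""
    c = torch.rand((n, k), device=device)
    d = torch.rand((k, n), device=device)
    torch.matmul(c, d)  # warm-up: the first call pays one-time setup costs
    sync(device)
    tic = time.perf_counter()
    for _ in range(repeats):
        torch.matmul(c, d)
    sync(device)  # wait for the queued kernels before stopping the clock
    return (time.perf_counter() - tic) / repeats
```

Calling `time_matmul(pick_device())` and `time_matmul(torch.device("cpu"))` then gives a like-for-like comparison of the two devices.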


I am afraid this speed difference will show up on any GPU-like device you use. Every operation sent to such hardware carries a small fixed overhead (launching the kernel, moving data, synchronizing with the CPU), so if your model is very small, that overhead dominates and makes the whole program a lot slower than on the CPU.
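To see that overhead on a model as small as an XOR network, one can time complete training steps on each device. This is a sketch under assumptions: the 2 -> hidden -> 1 MLP, the SGD settings and the `seconds_per_step` helper are hypothetical stand-ins, since xor_torch.py itself is not shown. Calling `.item()` on the last loss copies it back to the CPU, which forces all queued work to finish, so the measurement includes the actual compute.

```python
import time

import torch
from torch import nn


def seconds_per_step(device: torch.device, hidden: int, steps: int = 200) -> float:
    """Average seconds per training step of a tiny 2 -> hidden -> 1 MLP on XOR data."""
    torch.manual_seed(0)
    # Hypothetical stand-in for the model in xor_torch.py (the original script is not shown)
    model = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, 1)).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()
    x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], device=device)
    y = torch.tensor([[0.0], [1.0], [1.0], [0.0]], device=device)

    def step() -> torch.Tensor:
        # One forward/backward/update: roughly a dozen tiny kernel launches
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
        return loss

    loss = step()  # warm-up step
    loss.item()  # make sure the warm-up has finished before starting the clock
    tic = time.perf_counter()
    for _ in range(steps):
        loss = step()
    loss.item()  # forces the device to finish all queued work before stopping the clock
    return (time.perf_counter() - tic) / steps
```

Comparing `seconds_per_step(torch.device("cpu"), 8)` with `seconds_per_step(torch.device("mps"), 8)` on an Apple Silicon machine should make the fixed per-operation cost visible; a larger `hidden` size (or batch) is what gives the GPU-like device enough work per launch to amortize it.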

There is indeed an issue in the initialization, sorry about that. I opened an issue here: Conversion from int to float dtype is not working on MPS device · Issue #77849 · pytorch/pytorch · GitHub
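A quick way to check whether a given PyTorch build is affected is to compare an int-to-float conversion done on the device with the same conversion done on the CPU. This is a minimal sketch, not the reproducer from the issue; the helper name and the XOR-style input values are hypothetical. A result of `False` for `torch.device("mps")` would indicate that the on-device conversion produced wrong values.

```python
import torch


def int_to_float_matches_cpu(device: torch.device) -> bool:
    """Check whether an int64 -> float32 conversion on `device` agrees with the CPU result."""
    ints = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])  # int64 tensor on the CPU
    expected = ints.float()  # reference conversion on the CPU
    converted = ints.to(device).float().cpu()  # conversion performed on `device`
    return torch.equal(converted, expected)
```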

You can work around the issue by creating the tensor directly on the MPS device with the float dtype, instead of converting it after the move:

# x0 (the input data) and device (the MPS device) are defined earlier in xor_torch.py
x = X_train = torch.tensor(x0, dtype=torch.float32, device=device)
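A self-contained version of that workaround might look like the following. The XOR values for `x0` and the targets `y0` are hypothetical (the original script is not shown), and the device selection falls back to the CPU when MPS is not available so the snippet runs anywhere.

```python
import torch

# Hypothetical XOR inputs and targets; the real x0 comes from xor_torch.py, which is not shown
x0 = [[0, 0], [0, 1], [1, 0], [1, 1]]
y0 = [[0], [1], [1], [0]]

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Workaround: build the tensors directly on the device with a float dtype,
# instead of creating integer tensors and converting them after moving them to MPS
X_train = torch.tensor(x0, dtype=torch.float32, device=device)
y_train = torch.tensor(y0, dtype=torch.float32, device=device)
```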