This simple PyTorch "hello world" script (xor_torch.py, only partially shown) does not converge with MPS:
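import torch
import torch.nn as nn
import numpy as np
device = torch.device("mps")  # factor 100 slower AND erroneous !?!?
device = torch.device("cpu")  # the later assignment wins; on CPU the script converges
x0 = [[0, 0], [0, 1], [1, 0], [1, 1]]
y0 = [0, 1, 1, 0]
# ... (remainder of the script not shown)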
Any suggestions about the root cause?
PS: In general, PyTorch with MPS works fine here:
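import time
import torch
c = torch.rand((10000, 500)).to("mps")
d = torch.rand((500, 10000)).to("mps")
tic = time.time()
e = c @ d  # assumed: the timed operation is missing from the post; the shapes suggest a matrix product
toc = time.time()
print(toc - tic)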
print(c.device, "100 times faster than cpu, NICE!")
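A caveat on the PS timing: MPS, like CUDA, queues kernels asynchronously, so reading the clock right after the call may measure little more than the dispatch, which can inflate the apparent speed-up. The sketch below assumes the timed operation is the matrix product `c @ d` and a PyTorch version that provides `torch.mps.synchronize()` (2.0 or later). It does one warm-up run and waits for the device before each clock read. The helper names `sync` and `time_matmul` are hypothetical.

```python
import time
import torch

def sync(device: torch.device) -> None:
    # Block until all work queued on an asynchronous backend has finished
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize()

def time_matmul(device: torch.device, n: int = 10000, k: int = 500, repeats: int = 3) -> float:
    """Average seconds per (n x k) @ (k x n) product on `device`."""
    c = torch.rand((n, k), device=device)
    d = torch.rand((k, n), device=device)
    _ = c @ d  # warm-up: the first call can include one-off setup costs
    sync(device)
    tic = time.perf_counter()
    for _ in range(repeats):
        _ = c @ d
    sync(device)  # make sure every queued product has actually run
    return (time.perf_counter() - tic) / repeats

if __name__ == "__main__":
    devices = [torch.device("cpu")]
    if torch.backends.mps.is_available():
        devices.append(torch.device("mps"))
    for dev in devices:
        print(dev.type, f"{time_matmul(dev):.4f} s per matmul")
```

With the synchronization in place, the CPU/MPS ratio you measure may differ noticeably from the unsynchronized "100 times" figure.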
I am afraid this speed difference will show up on any GPU-like device you use. Every operation sent to such hardware carries a small fixed overhead (kernel launch, scheduling, synchronization with the host), so if your model is very small, that overhead dominates the actual computation and makes the whole program a lot slower than on the CPU.
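One way to observe this overhead is to time a single tiny "layer" step (a matrix product followed by `tanh`) at a toy size and at a large size on each available device. This is a hypothetical benchmark, not code from the thread, and it assumes PyTorch 2.0 or later for `torch.mps.synchronize()`; `per_step_seconds` is an illustrative name. If the overhead explanation holds, the GPU-like device should lose to the CPU at the tiny size and pull ahead at the large one, though the exact numbers depend on the machine.

```python
import time
import torch

def sync(device: torch.device) -> None:
    # Wait for queued GPU work so the timer measures execution, not just dispatch
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize()

def per_step_seconds(device: torch.device, size: int, steps: int = 50) -> float:
    """Average wall time of one matmul + tanh step on (size x size) tensors."""
    x = torch.rand((size, size), device=device)
    w = torch.rand((size, size), device=device)
    x = torch.tanh(x @ w)  # warm-up step
    sync(device)
    tic = time.perf_counter()
    for _ in range(steps):
        x = torch.tanh(x @ w)  # each step launches a fixed number of small kernels
    sync(device)
    return (time.perf_counter() - tic) / steps

if __name__ == "__main__":
    devices = [torch.device("cpu")]
    if torch.backends.mps.is_available():
        devices.append(torch.device("mps"))
    for size in (4, 1024):
        for dev in devices:
            us = per_step_seconds(dev, size) * 1e6
            print(f"size={size:5d} {dev.type:>4}: {us:10.1f} us/step")
```

At size 4 the arithmetic is negligible, so the per-step time there is essentially the fixed per-operation cost of the device.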
There is indeed an issue in the initialization; sorry about that. I opened an issue here:
Conversion from int to float dtype is not working on MPS device · Issue #77849 · pytorch/pytorch · GitHub
You can work around the issue by creating the tensor directly on the MPS device with the desired dtype, instead of converting it afterwards:
x = X_train = torch.tensor(x0, dtype=torch.float32, device=device)
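Putting the workaround into a complete training loop, here is a minimal, hypothetical stand-in for the truncated xor_torch.py. The network size, optimizer, learning rate, and epoch count are assumptions, not the original author's choices, and `train_xor` is an illustrative name. Both input and target tensors are created as float32 directly on the target device, so no int-to-float conversion happens on MPS.

```python
import torch
import torch.nn as nn

def train_xor(device: torch.device, epochs: int = 2000, seed: int = 0):
    """Train a tiny MLP on XOR; return (final loss, predicted labels)."""
    torch.manual_seed(seed)
    x0 = [[0, 0], [0, 1], [1, 0], [1, 1]]
    y0 = [0, 1, 1, 0]
    # Workaround: build float32 tensors directly on the device (no on-device int->float cast)
    x = torch.tensor(x0, dtype=torch.float32, device=device)
    y = torch.tensor(y0, dtype=torch.float32, device=device).unsqueeze(1)
    model = nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 1)).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.05)
    loss_fn = nn.BCEWithLogitsLoss()  # expects raw logits, applies sigmoid internally
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    with torch.no_grad():
        # A logit above 0 corresponds to a sigmoid probability above 0.5
        preds = (model(x) > 0).squeeze(1).int().cpu().tolist()
    return loss.item(), preds

if __name__ == "__main__":
    print("cpu:", train_xor(torch.device("cpu")))
    if torch.backends.mps.is_available():
        print("mps:", train_xor(torch.device("mps")))
```

With the conversion avoided, the CPU and MPS runs should end with losses of similar magnitude and the same predictions. Building the float tensor on the CPU first and then moving it, `torch.tensor(x0, dtype=torch.float32).to(device)`, also keeps the dtype conversion off the MPS device.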