Low Accuracy with MPS

This is a follow-up to this topic, in which I posted the following:
"A similar issue is found when executing the sample code here: Quickstart — PyTorch Tutorials 2.0.1+cu117 documentation

Specifically, in the function test(), the line:

correct += (pred.argmax(1) == y).type(torch.float).sum().item()

When device = 'mps', it always results in 10% accuracy. When device = 'cpu', the accuracy is as expected."
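A 10% score is worth reading closely. The Quickstart trains on FashionMNIST, whose test set has 10 classes with 1,000 images each, so a model that predicts the same class for every image scores exactly 10%. The minimal sketch below shows that arithmetic and a histogram check on predictions. The labels and predictions are synthetic stand-ins, not real model output, and the helper name `accuracy` is hypothetical.

```python
from collections import Counter

def accuracy(preds, labels):
    """Hypothetical helper: fraction of positions where prediction equals label."""
    correct = sum(p == t for p, t in zip(preds, labels))
    return correct / len(labels)

# Synthetic stand-in for a balanced 10-class test set: 1,000 labels per class.
labels = [c for c in range(10) for _ in range(1000)]

# A degenerate "model" that always predicts class 3 (an arbitrary choice).
constant_preds = [3] * len(labels)

print(accuracy(constant_preds, labels))  # 0.1

# A histogram dominated by a single class suggests the model outputs
# collapsed, rather than the accuracy line itself being wrong.
print(Counter(constant_preds).most_common(3))  # [(3, 10000)]
```

In the real test() loop, collecting pred.argmax(1) across batches and passing it to Counter would show whether the MPS run is stuck on one class.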

I am creating a new topic as the old one doesn’t seem to be getting any traction, and this issue may be a different bug.

I’ve been training on MPS and use exactly that line as well.

No accuracy issues. Not sure how to help without more info.

Maybe you can add the torch, torchvision and Python versions. There could also be a problem with a specific operation, so the network architecture may be useful as well.
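One way to narrow down a specific operation is to run the same input through individual layers on CPU and on MPS and compare the outputs. The minimal sketch below assumes layer shapes similar to the Quickstart network (flattened 28x28 inputs, 10 classes). The helper name `compare` and the tolerance `atol=1e-4` are illustrative choices, not part of the tutorial. When MPS is unavailable the sketch falls back to CPU, and both sides then trivially agree.

```python
import torch
from torch import nn

# Use MPS when present; otherwise fall back to CPU so the sketch still runs
# (both sides then run on CPU and trivially agree).
device = "mps" if torch.backends.mps.is_available() else "cpu"
torch.manual_seed(0)

# Random batch shaped like flattened 28x28 images (synthetic data).
x = torch.randn(64, 28 * 28)

# Layers of the kind used in the Quickstart network (an assumption).
ops = {
    "flatten": nn.Flatten(),
    "linear": nn.Linear(28 * 28, 512),
    "relu": nn.ReLU(),
}

def compare(name, module, inp, atol=1e-4):
    """Hypothetical helper: run `module` on CPU and on `device`, compare outputs."""
    with torch.no_grad():
        cpu_out = module(inp)
        dev_out = module.to(device)(inp.to(device)).cpu()
    module.to("cpu")  # move parameters back for later CPU runs
    diff = (cpu_out - dev_out).abs().max().item()
    ok = torch.allclose(cpu_out, dev_out, atol=atol)
    print(f"{name:8s} match={ok} max_abs_diff={diff:.2e}")
    return ok

results = {name: compare(name, m, x) for name, m in ops.items()}

# The accuracy expression from test(), evaluated on both devices.
logits = torch.randn(64, 10)
labels = torch.randint(0, 10, (64,))
cpu_correct = (logits.argmax(1) == labels).float().sum().item()
dev_correct = (logits.to(device).argmax(1) == labels.to(device)).float().sum().item()
print("correct on cpu:", cpu_correct, "| on", device + ":", dev_correct)
```

If a layer, or the argmax/comparison step, disagrees on MPS but not on CPU, that operation would be the one worth reporting.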

Actually, I think instead of .type(torch.float) I just use .float()
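Both spellings cast the boolean match mask to 32-bit floats before summing. They should therefore count the same number of correct predictions, and switching between them should not change the accuracy on its own. The check below uses made-up logits (illustrative, not Quickstart data). Note that `torch.float` is a dtype rather than a function, so the intended expression is `.type(torch.float)`, not `type(torch.float())`.

```python
import torch

# Hypothetical logits for 4 samples over 3 classes, and their true labels.
pred = torch.tensor([[0.1, 2.0, 0.3],
                     [1.5, 0.2, 0.1],
                     [0.0, 0.1, 3.0],
                     [0.9, 0.8, 0.7]])
y = torch.tensor([1, 0, 1, 0])

# Spelling used in the Quickstart: torch.float is a dtype (alias of torch.float32).
a = (pred.argmax(1) == y).type(torch.float).sum().item()

# Shorter spelling: .float() is equivalent to .to(torch.float32).
b = (pred.argmax(1) == y).float().sum().item()

print(a, b)  # 3.0 3.0

# torch.float itself is not callable, so `type(torch.float())` would raise.
try:
    torch.float()
except TypeError as err:
    print("TypeError:", err)
```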

Thanks for the response.
I tried changing .type(torch.float) to .float(), but that didn’t change anything; I still got an accuracy of 10% with MPS on all epochs.

I am using the following:
Python: 3.11.4
PyTorch: 2.0.1
Torchvision: 0.15.2

I am using an Intel-based Mac (older AMD GPU, macOS 12.6.9), if that makes any difference. I believe it supports Metal v2.0, not 3.0…
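The details requested in this thread (Python, PyTorch and torchvision versions, macOS version, CPU architecture, MPS status) can be collected in one go. The minimal sketch below assumes torch is importable and handles torchvision being absent. `torch.backends.mps.is_built()` and `torch.backends.mps.is_available()` are existing PyTorch functions. The `report` layout is just an illustration.

```python
import platform

import torch

try:
    import torchvision
    tv_version = torchvision.__version__
except ImportError:
    tv_version = "not installed"

report = {
    "python": platform.python_version(),
    "torch": torch.__version__,
    "torchvision": tv_version,
    # mac_ver() returns an empty release string on non-macOS systems.
    "macos": platform.mac_ver()[0] or "n/a",
    "machine": platform.machine(),  # e.g. x86_64 on an Intel-based Mac
    # is_built(): this torch binary was compiled with MPS support.
    "mps_built": torch.backends.mps.is_built(),
    # is_available(): MPS can actually be used on this machine and OS.
    "mps_available": torch.backends.mps.is_available(),
}

for key, value in report.items():
    print(f"{key:14s} {value}")
```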