The first screenshot shows the code of a simple training loop running on CPU. It makes sense that loss.backward() is slow on CPU. The second screenshot shows the same loop running on MPS, where loss.backward() is faster than on CPU, which also makes sense.
However, optim.step(), which should be a cheap operation, is far too slow, especially on MPS, which makes no sense to me.
I tried:
None of the above helped.
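One thing worth ruling out when reading per-line timings like these: MPS, like CUDA, queues kernels asynchronously, so work launched by loss.backward() can finish (and be billed) inside whichever later call happens to block, which may be optim.step(). Adam also lazily creates its per-parameter state on the first step() call, so the first iteration is not representative. Below is a minimal sketch, assuming a hypothetical toy model and random data (the helper names `sync` and `time_phases` are made up for illustration), that synchronizes the device around each phase and skips the warm-up iteration. It relies on `torch.mps.synchronize()`, which is available in recent PyTorch releases.

```python
import time

import torch
import torch.nn as nn


def sync(device: torch.device) -> None:
    # MPS ops are asynchronous; block until all queued kernels finish so the
    # elapsed time is attributed to the phase that actually launched them.
    if device.type == "mps":
        torch.mps.synchronize()


def time_phases(device: torch.device, steps: int = 20) -> dict:
    # Hypothetical toy model and data; substitute the real training loop.
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    x = torch.randn(64, 512, device=device)
    y = torch.randint(0, 10, (64,), device=device)

    totals = {"forward": 0.0, "backward": 0.0, "step": 0.0}
    for i in range(steps + 1):
        optim.zero_grad()
        sync(device)
        t0 = time.perf_counter()
        loss = loss_fn(model(x), y)
        sync(device)
        t1 = time.perf_counter()
        loss.backward()
        sync(device)
        t2 = time.perf_counter()
        optim.step()
        sync(device)
        t3 = time.perf_counter()
        if i == 0:
            # Skip the first iteration: it includes one-time costs such as
            # Adam allocating its optimizer state.
            continue
        totals["forward"] += t1 - t0
        totals["backward"] += t2 - t1
        totals["step"] += t3 - t2

    # Average seconds per iteration for each phase.
    return {phase: total / steps for phase, total in totals.items()}


if __name__ == "__main__":
    dev = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    for phase, secs in time_phases(dev).items():
        print(f"{dev.type} {phase}: {secs * 1e3:.2f} ms")
```

If optim.step() still dominates after synchronizing, the slowdown is genuinely in the optimizer update rather than an artifact of asynchronous timing.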