optim.step() is significantly slower on MPS

The first screenshot shows a simple training loop (code included) running on the CPU. It makes sense that loss.backward() is slow on the CPU. The second screenshot shows the same loop running on MPS, where loss.backward() is faster than on the CPU, which also makes sense.

However, optim.step(), which should be a cheap operation, is far slower than expected on both devices. On MPS it is especially slow, which I cannot explain.
I tried:
export PYTORCH_ENABLE_MPS_FALLBACK=0
export PYTORCH_ENABLE_MPS_FALLBACK=1
torch.mps.synchronize()
None of these helped.
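One thing worth ruling out: MPS kernels are queued asynchronously, so a wall-clock timer around loss.backward() may stop before the backward pass has actually finished on the GPU. The leftover work can then show up as time spent in whichever later call waits on the GPU, possibly optim.step(). Where the synchronize call goes matters, so the sketch below calls it after every phase before reading the clock, and it also skips the first (warm-up) iteration. It assumes PyTorch 2.x (where torch.mps.synchronize() exists). The model, batch size, Adam optimizer and the helpers `sync` and `time_phases` are hypothetical stand-ins, since the original loop is only visible in the screenshots.

```python
import time

import torch
import torch.nn as nn


def sync(device: torch.device) -> None:
    """Block until all queued work on `device` has finished."""
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize()
    # CPU ops run synchronously, so there is nothing to wait for.


def time_phases(device: torch.device, steps: int = 20) -> dict:
    """Return average seconds spent in forward, backward and optimizer step."""
    torch.manual_seed(0)
    # Hypothetical stand-in model and data (the real loop is only in a screenshot).
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    x = torch.randn(256, 512, device=device)
    y = torch.randint(0, 10, (256,), device=device)

    totals = {"forward": 0.0, "backward": 0.0, "step": 0.0}
    for i in range(steps + 1):
        optim.zero_grad()
        sync(device)
        t0 = time.perf_counter()
        loss = loss_fn(model(x), y)
        sync(device)
        t1 = time.perf_counter()
        loss.backward()
        sync(device)  # without this, pending backward kernels may be billed to step()
        t2 = time.perf_counter()
        optim.step()
        sync(device)
        t3 = time.perf_counter()
        if i == 0:
            continue  # skip the warm-up iteration (kernel setup, allocations)
        totals["forward"] += t1 - t0
        totals["backward"] += t2 - t1
        totals["step"] += t3 - t2
    return {k: v / steps for k, v in totals.items()}


if __name__ == "__main__":
    devices = [torch.device("cpu")]
    if torch.backends.mps.is_available():
        devices.append(torch.device("mps"))
    for dev in devices:
        print(dev, time_phases(dev))
```

If optim.step() is still disproportionately slow on MPS with this kind of synchronized timing, the slowdown is more likely real rather than a measurement artifact.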