PyTorch Chatbot Tutorial: Loss not decreasing as expected on the MPS device (Apple Silicon)

After following the PyTorch Chatbot Tutorial on a Mac with Apple Silicon, I could not reach the relatively low loss reported in the tutorial. To rule out a bug of my own, I downloaded the .ipynb file from the tutorial page and changed only the device to MPS:

```python
# USE_CUDA = torch.cuda.is_available()
# device = torch.device("cuda" if USE_CUDA else "cpu")
device = torch.device("mps")
```
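Hard-coding `torch.device("mps")` fails on machines without MPS support. A minimal sketch of a more defensive selection, assuming a PyTorch build recent enough to have `torch.backends.mps` (1.12 or later); the helper name `pick_device` is hypothetical and not part of the tutorial:

```python
import torch


def pick_device() -> torch.device:
    # Prefer MPS when this PyTorch build supports it and the machine exposes it
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # Otherwise keep the tutorial's original logic: CUDA if present, else CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = pick_device()
print(device)
```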

I also enabled the CPU fallback for operations not yet implemented on MPS by running `conda env config vars set PYTORCH_ENABLE_MPS_FALLBACK=1` in my conda environment.
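Variables set with `conda env config vars set` only take effect after the environment is reactivated (and, in Jupyter, after the kernel is restarted), so it may be worth confirming from inside the notebook that the flag is visible. A minimal sketch, assuming it is acceptable to set the flag from Python as a backup; it is set before `import torch` because PyTorch may read it when the library loads:

```python
import os

# Set the flag before importing torch, in case PyTorch reads it at load time.
# setdefault keeps any value already provided by the conda environment.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch  # noqa: E402  (imported after the env var on purpose)

print("PYTORCH_ENABLE_MPS_FALLBACK =", os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK"))
print("torch", torch.__version__)
```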

Despite running the same code as the tutorial on the same input data, the loss I end up with is ~5.1049, while the tutorial reports ~2.9412. I am aware that I cannot reproduce the tutorial’s result exactly in this setting, but I am surprised by how large the gap is. Running the same .ipynb file in Google Colab (on CPU) gives a loss of ~2.5723.
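To tell run-to-run randomness apart from a device effect, it helps to fix the random seeds before training. As far as I can tell, the tutorial draws training batches and teacher-forcing decisions from Python's `random` module and initializes weights with PyTorch's generator, so both need seeding. A minimal sketch; the helper name `seed_everything` is hypothetical:

```python
import random

import torch


def seed_everything(seed: int) -> None:
    # Python's RNG: batch sampling and teacher-forcing coin flips
    random.seed(seed)
    # PyTorch's CPU generator (weight init); recent releases also seed CUDA/MPS here
    torch.manual_seed(seed)
    # Seed the MPS generator explicitly as well, in case an older release does not
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


seed_everything(0)
print(random.random(), torch.rand(2))
```

Seeding makes a run repeatable on one device; it does not make CPU and MPS produce identical numbers, because each backend has its own generator and floating-point kernels. Running a few seeds per device shows how much of the spread is ordinary randomness.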

Is the difference in loss between MPS and the tutorial/Google Colab runs due to stochasticity in PyTorch’s backend, or is something else going on?
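One way to probe the "something else" hypothesis is to run identical weights and inputs through a layer on CPU and on MPS and compare both the outputs and the gradients. A minimal sketch, assuming a small 2-layer bidirectional `nn.GRU` as a hypothetical stand-in for the tutorial's encoder (the function name and sizes are illustrative); on a machine without MPS it simply compares CPU with CPU:

```python
import copy

import torch
import torch.nn as nn


def compare_against_cpu(device: torch.device, seed: int = 0) -> tuple[float, float]:
    """Return (max output gap, max gradient gap) between CPU and `device`."""
    torch.manual_seed(seed)
    # Hypothetical stand-in for the tutorial's encoder
    cpu_model = nn.GRU(input_size=16, hidden_size=32, num_layers=2, bidirectional=True)
    dev_model = copy.deepcopy(cpu_model).to(device)  # identical weights on `device`

    x = torch.randn(10, 4, 16)  # (seq_len, batch, features)
    out_cpu, _ = cpu_model(x)
    out_dev, _ = dev_model(x.to(device))

    # Backpropagate the same scalar through both copies
    out_cpu.sum().backward()
    out_dev.sum().backward()

    out_gap = (out_cpu - out_dev.cpu()).abs().max().item()
    grad_gap = max(
        (p_cpu.grad - p_dev.grad.cpu()).abs().max().item()
        for p_cpu, p_dev in zip(cpu_model.parameters(), dev_model.parameters())
    )
    return out_gap, grad_gap


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
out_gap, grad_gap = compare_against_cpu(device)
print(f"{device.type}: max output gap {out_gap:.2e}, max grad gap {grad_gap:.2e}")
```

If both gaps stay at float32 round-off level, this layer behaves the same on both backends and the remaining difference is more likely ordinary training randomness; a large gap would point at an MPS-specific issue. The same pattern can be applied to the tutorial's own encoder and decoder modules.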