Device = "mps" is producing NaN weights in nn.Embedding

I have an NLP model that trains fine in the following contexts:

  • Windows11 + CPU
  • Windows11 + CUDA
  • Ubuntu20.04 + CPU
  • Ubuntu20.04 + CUDA
  • macOS12.5 + CPU

However, when I run the same model with “mps” as the device, I get unexpected behavior: the nn.Embedding layers in my model are initialized normally, but their weights quickly train to NaN values. No error is raised; the loss just never improves. Since the model trains fine if I simply change my device to “cpu”, I believe this is likely an issue with my virtual environment setup. Does anyone have insight into what I am doing wrong?


Environment Setup Information
To set up an M1-compatible environment, I used the following commands:

CONDA_SUBDIR=osx-arm64 conda create -n test_environment python=3.9 -c conda-forge

conda env config vars set CONDA_SUBDIR=osx-arm64

pip3 install torch torchvision torchaudio


Output from Simple Sanity Checks
These sanity checks seem to suggest the environment should work:

import torch
import platform

print(f'Platform: {platform.platform()}')
print(f'torch.has_mps: {torch.has_mps}')
print(f'MPS is available: {torch.backends.mps.is_available()}')
print(f'Pytorch was built with MPS: {torch.backends.mps.is_built()}')

Platform: macOS-12.5-arm64-arm-64bit
torch.has_mps: True
MPS is available: True
Pytorch was built with MPS: True
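
The flags above only confirm that the MPS backend is present, not that it returns the same values as the CPU. A minimal sketch of a stronger check follows, assuming a small embedding lookup is representative enough to be worth comparing. The helper names `pick_device` and `embedding_matches_cpu` are hypothetical and not part of PyTorch.

```python
import copy

import torch
import torch.nn as nn


def pick_device() -> torch.device:
    """Return the MPS device when it is usable, otherwise fall back to CPU."""
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def embedding_matches_cpu(device, num_embeddings=16, dim=8, seed=0) -> bool:
    """Compare an nn.Embedding forward pass on `device` against the CPU result."""
    torch.manual_seed(seed)
    emb_cpu = nn.Embedding(num_embeddings, dim)
    idx = torch.randint(0, num_embeddings, (4, 5))
    expected = emb_cpu(idx).detach()

    # Same weights, moved to the device under test.
    emb_dev = copy.deepcopy(emb_cpu).to(device)
    actual = emb_dev(idx.to(device)).detach().cpu()
    return torch.allclose(expected, actual, atol=1e-6)


if __name__ == "__main__":
    device = pick_device()
    print(f"Device under test: {device}")
    print(f"Embedding lookup matches CPU: {embedding_matches_cpu(device)}")
```

If this check passes on “mps”, the forward lookup is probably not the culprit, and the backward pass or the optimizer step would be the next thing to inspect.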


Here is some additional debugging context for this problem. I’m arbitrarily examining the weights during the 5th mini-batch in my first epoch.

With device = “cpu”.
self.pos_embedding.weight produces:

Parameter containing:
tensor([[ 0.0549, -0.0284, -0.1278, …, -0.0946, -0.0928, -0.0088],
[ 0.0771, -0.0170, -0.0525, …, 0.1236, 0.0622, 0.0237],
[-0.0847, -0.0200, -0.1097, …, -0.0541, 0.0832, 0.0280],
…,
[ 0.1198, -0.0781, -0.0658, …, -0.0878, 0.0403, 0.1020],
[ 0.1255, 0.0605, 0.0063, …, -0.0283, -0.0824, -0.1120],
[ 0.0942, -0.0514, -0.0760, …, -0.0519, 0.0670, -0.0561]],
requires_grad=True)

With device = “mps”.
self.pos_embedding.weight produces:

Parameter containing:
tensor([[ nan, nan, nan, …, nan, nan, nan],
[ nan, nan, nan, …, nan, nan, nan],
[ nan, nan, nan, …, nan, nan, nan],
…,
[-0.1200, -0.0589, -0.0015, …, -0.0278, 0.1136, 0.0352],
[-0.0968, 0.0428, -0.0388, …, -0.1297, -0.0833, -0.1068],
[ 0.1078, -0.0055, 0.0330, …, -0.0212, -0.0376, 0.1046]],
device='mps:0', requires_grad=True)
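
In the “mps” output above, only some rows are NaN, so it may help to record exactly which rows go bad and at which step. Here is a minimal sketch of such a check, assuming it is called on the model after each optimizer step. The helper names `nonfinite_rows` and `nonfinite_report` are hypothetical, and the tensors are copied to the CPU before inspection so the check itself does not depend on the MPS backend.

```python
import torch
import torch.nn as nn


def nonfinite_rows(param: torch.Tensor) -> list:
    """Return indices along dim 0 whose entries contain NaN or +/-inf."""
    w = param.detach().cpu()
    if w.numel() == 0:
        return []
    bad = ~torch.isfinite(w)
    rows = w.shape[0] if w.dim() > 0 else 1
    # Collapse all trailing dimensions so each row gets a single flag.
    return bad.reshape(rows, -1).any(dim=1).nonzero().flatten().tolist()


def nonfinite_report(model: nn.Module) -> dict:
    """Map parameter name -> list of affected rows, for affected parameters only."""
    report = {}
    for name, param in model.named_parameters():
        rows = nonfinite_rows(param)
        if rows:
            report[name] = rows
    return report
```

Comparing the reported rows with the token or position indices used in the preceding mini-batches could show whether only the rows that received gradient updates are affected.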

Please take a look at this issue: “Incorrect tensor conversion to m1 MPS”, Issue #83015 in pytorch/pytorch on GitHub.

Thank you for the response. Unfortunately, all my .to(device) calls already use the default non_blocking=False.

I’ll play around a bit and see if I can extract a minimal example of the error I’m getting.
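
As a starting point for a minimal example, here is a rough sketch that trains a tiny embedding plus linear head on a given device and reports the first step at which the embedding weights stop being finite. The model, data, and hyperparameters are made up for illustration and are not taken from the real model; `first_nan_step` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def first_nan_step(device, steps=20, seed=0):
    """Train a toy model on `device`; return the first step with non-finite
    embedding weights, or None if the weights stay finite throughout."""
    torch.manual_seed(seed)
    emb = nn.Embedding(32, 8).to(device)
    head = nn.Linear(8, 2).to(device)
    params = list(emb.parameters()) + list(head.parameters())
    opt = torch.optim.SGD(params, lr=0.1)

    # Build the batch on the CPU so every device sees identical data.
    idx = torch.randint(0, 32, (16, 4)).to(device)
    target = torch.randint(0, 2, (16,)).to(device)

    for step in range(steps):
        opt.zero_grad()
        logits = head(emb(idx).mean(dim=1))
        loss = nn.functional.cross_entropy(logits, target)
        loss.backward()
        opt.step()
        # Inspect a CPU copy so the check does not rely on the device kernels.
        if not torch.isfinite(emb.weight.detach().cpu()).all():
            return step
    return None


if __name__ == "__main__":
    for name in ("cpu", "mps"):
        if name == "mps" and not torch.backends.mps.is_available():
            continue
        print(name, first_nan_step(torch.device(name)))
```

If the toy model stays finite on “mps”, swapping in pieces of the real model one at a time (optimizer, loss, padding_idx, and so on) might narrow down which operation introduces the NaN values.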