TorchRuntimeError: Failed running call_module L__self___position_embedding_table(*(FakeTensor(..., device='mps', ...), **{}): Unhandled FakeTensor Device Propagation for aten.index_select.default, found two different devices mps:0, mps

I am trying to run a transformer model on a MacBook with an M2 chip. I have installed PyTorch with MPS support.

So, I have set device = torch.device("mps")

I am following this tutorial notebook

When I run the training cell in the notebook, I get this error on the following line:

self.positional_embedding = self.position_embedding_table(torch.arange(T, device=self.device))  # (T,C)

where position_embedding_table is declared as

self.position_embedding_table = nn.Embedding(block_size, n_embd)

Another piece of information: I am compiling the model after instantiating it.
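
For reference, here is a minimal sketch of the setup described above. It is a reduced stand-in, not the tutorial's code: the class name TinyModel, the run helper, and the sizes block_size = 8 and n_embd = 16 are assumptions made only for illustration. It falls back to CPU when MPS is not available, so the eager path can be checked anywhere.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
block_size, n_embd = 8, 16


class TinyModel(nn.Module):
    """Stripped-down stand-in for the tutorial model (hypothetical name)."""

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

    def forward(self, idx):
        B, T = idx.shape
        # Same pattern as the line that raises the error in the question.
        return self.position_embedding_table(
            torch.arange(T, device=self.device)
        )  # (T, C)


def run(device, use_compile=False):
    """Build the model on `device`, optionally compile it, and do one forward pass."""
    model = TinyModel(device).to(device)
    if use_compile:
        model = torch.compile(model)
    idx = torch.zeros((2, block_size), dtype=torch.long, device=device)
    return model(idx)


if __name__ == "__main__":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(run(device).shape)  # eager mode, no torch.compile
```

Calling run(device, use_compile=True) on the MPS device is the configuration in which I see the error.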

UPDATE 1: The error does not occur when I run the model without torch.compile.

UPDATE 2: The error also does not occur with torch.compile when I run the code on a CUDA device.
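
One observation that may be relevant to the "mps:0, mps" part of the message: a torch.device built from "mps" carries no index, while one built from "mps:0" carries index 0, and the two objects do not compare equal. The snippet below only illustrates that distinction; whether it is the actual cause of the error under torch.compile is an assumption on my part.

```python
import torch

# Constructing device objects does not require the MPS backend to be present.
bare = torch.device("mps")       # no explicit index
indexed = torch.device("mps:0")  # explicit index 0

print(bare.index, indexed.index)  # the bare device has no index set
print(bare == indexed)            # the two device objects are compared by type and index
```

If this mismatch is what trips the compiler, one thing to try would be creating the position indices on the device of an existing tensor (for example torch.arange(T, device=idx.device), where idx is the input batch) instead of on the stored self.device. I have not confirmed that this resolves the error.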

Why is it showing two different "mps" devices, and how can I solve this issue?