MacBook M2 Pro MPS issue

Hello,

I’ve built a Transformer from scratch following the “Attention Is All You Need” (AIAYN) paper, with some slight tweaks to the learning rate, optimizer, etc., and I’m running into issues with the MPS backend. At the start of my training loop I’m doing:
device = torch.device("cuda" if torch.cuda.is_available() else "mps")
and I’m getting this error:
(venv-3.10) (base) bardia@Bardias-MacBook-Pro my-transformer-from-scratch % python3 tests.py
Traceback (most recent call last):
  File "/Users/bardia/Desktop/my-transformer-from-scratch/tests.py", line 678, in <module>
    transformer_train_valid()
  File "/Users/bardia/Desktop/my-transformer-from-scratch/tests.py", line 611, in transformer_train_valid
    logits = transformer(src_ids, tgt_in_ids, src_pad_mask, tgt_in_pad_mask)
  File "/Users/bardia/Desktop/my-transformer-from-scratch/venv-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/bardia/Desktop/my-transformer-from-scratch/venv-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/bardia/Desktop/my-transformer-from-scratch/full_transformer.py", line 25, in forward
    input_embeddings = self.input_embed(inputs)
  File "/Users/bardia/Desktop/my-transformer-from-scratch/venv-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/bardia/Desktop/my-transformer-from-scratch/venv-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/bardia/Desktop/my-transformer-from-scratch/embedding_block.py", line 15, in forward
    return self.embedding(x) * (1.0 / math.sqrt(self.d_model))
  File "/Users/bardia/Desktop/my-transformer-from-scratch/venv-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/bardia/Desktop/my-transformer-from-scratch/venv-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/bardia/Desktop/my-transformer-from-scratch/venv-3.10/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 192, in forward
    return F.embedding(
  File "/Users/bardia/Desktop/my-transformer-from-scratch/venv-3.10/lib/python3.10/site-packages/torch/nn/functional.py", line 2546, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Placeholder storage has not been allocated on MPS device!
I1114 23:32:34.270000 34820 venv-3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:520] TorchDynamo attempted to trace the following frames: [
I1114 23:32:34.270000 34820 venv-3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:520]
I1114 23:32:34.270000 34820 venv-3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:520] ]
I1114 23:32:34.270000 34820 venv-3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py:811] TorchDynamo compilation metrics:
I1114 23:32:34.270000 34820 venv-3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py:811] Function, Runtimes (s)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats guard_or_defer_runtime_assert: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats _inner_evaluate_expr: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats _find: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats simplify: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats replace: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats get_axioms: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats _maybe_evaluate_static_worker: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V1114 23:32:34.271000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats safe_expand: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V1114 23:32:34.272000 34820 venv-3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:184] lru_cache_stats uninteresting_files: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
(venv-3.10) (base) bardia@Bardias-MacBook-Pro my-transformer-from-scratch %
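The device line at the top of the loop falls back to "mps" whenever CUDA is missing, without checking that MPS is actually usable. A minimal sketch of a more defensive selection, assuming PyTorch 1.12 or newer (where `torch.backends.mps` exists); the helper name `pick_device` is hypothetical:

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then MPS, and fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # is_available() reports whether MPS can actually be used at runtime;
    # is_built() only reports whether this PyTorch binary was compiled with MPS.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.backends.mps.is_built():
        print("MPS is built into this PyTorch install but not available; "
              "falling back to CPU.")
    return torch.device("cpu")


device = pick_device()
print(f"Using device: {device}")
```

Printing the chosen device at startup makes it obvious which backend a given run is actually using.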

Important note: when I set the device in that line to ‘cpu’ instead, I don’t get any issues. I ran into other problems earlier when moving things to MPS, which is why I created my venv with Python 3.10: I read that this is the best way to install the PyTorch nightly builds (the “nightshades” thing) for MPS. But now I’m running into issues again at the start of my training loop.
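Since the same code runs on CPU and the traceback fails inside the very first embedding lookup, one common cause of this error is a device mismatch: the module’s weights are on MPS while the index tensor is still on the CPU (or the other way round). A minimal sketch, assuming that is the cause here; `nn.Embedding` stands in for the real `transformer`, and the `check_same_device` helper is hypothetical. In the actual loop this would mean calling `.to(device)` on `transformer` once and on `src_ids`, `tgt_in_ids`, `src_pad_mask` and `tgt_in_pad_mask` (the names from the traceback) for every batch:

```python
import torch
import torch.nn as nn


def check_same_device(module: nn.Module, *tensors: torch.Tensor) -> None:
    """Hypothetical helper: fail early if an input is not on the module's device."""
    param_device = next(module.parameters()).device
    for i, t in enumerate(tensors):
        if t.device != param_device:
            raise RuntimeError(
                f"input {i} is on {t.device}, but the module's weights are on {param_device}"
            )


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Stand-in for the real model: move its parameters to the device once.
model = nn.Embedding(num_embeddings=100, embedding_dim=16).to(device)

# Batches (e.g. from a DataLoader) are created on the CPU by default,
# so every tensor passed to the model must be moved before the forward pass.
src_ids = torch.randint(0, 100, (2, 8)).to(device)

check_same_device(model, src_ids)
out = model(src_ids)
print(out.shape, out.device)
```

Calling a check like this right before the forward pass turns the opaque MPS error into a message that names the tensor that was left behind.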