JITed nn.TransformerDecoderLayer runs significantly slower than in eager mode

I was using nn.TransformerDecoderLayer for an experiment on dynamic shape compilation on my Apple MPS device. However, simply running a single nn.TransformerDecoderLayer (one layer taking the target and memory tensors) with and without torch.compile gave me the following result.

PyTorch 2.10.0, device=mps, dtype=float32, compile_mode=default
shape: tgt=(8, 128, 512), memory=(8, 256, 512)
config: nhead=8, dim_feedforward=2048, dropout=0.0, repeats=100, warmup=20
eager: 4.950 ms/iter
compile: 30.413 ms/iter
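
For reference, a measurement like the one above could be reproduced with a sketch along these lines. It assumes batch_first=True, eval mode, and torch.no_grad(), none of which are stated in the numbers above, and the helper names (pick_device, sync, bench, make_inputs, main) are made up for illustration.

```python
import time

import torch
import torch.nn as nn


def pick_device() -> torch.device:
    """Use MPS when available; otherwise fall back to CPU so the sketch still runs."""
    return torch.device("mps" if torch.backends.mps.is_available() else "cpu")


def sync(device: torch.device) -> None:
    """Wait for queued MPS work so wall-clock timings cover the actual compute."""
    if device.type == "mps":
        torch.mps.synchronize()


def bench(fn, args, device, warmup=20, repeats=100) -> float:
    """Return the mean time in milliseconds of one call to fn(*args)."""
    with torch.no_grad():
        for _ in range(warmup):  # warmup also absorbs one-time compilation cost
            fn(*args)
        sync(device)
        start = time.perf_counter()
        for _ in range(repeats):
            fn(*args)
        sync(device)
        elapsed = time.perf_counter() - start
    return elapsed * 1000.0 / repeats


def make_inputs(device, batch=8, tgt_len=128, mem_len=256, d_model=512):
    """Build one decoder layer plus target and memory tensors."""
    layer = nn.TransformerDecoderLayer(
        d_model=d_model, nhead=8, dim_feedforward=2048,
        dropout=0.0, batch_first=True,  # batch_first is an assumption
    ).to(device).eval()
    tgt = torch.randn(batch, tgt_len, d_model, device=device)
    memory = torch.randn(batch, mem_len, d_model, device=device)
    return layer, tgt, memory


def main() -> None:
    device = pick_device()
    layer, tgt, memory = make_inputs(device)
    print(f"eager:   {bench(layer, (tgt, memory), device):.3f} ms/iter")
    compiled = torch.compile(layer)  # default compile mode
    print(f"compile: {bench(compiled, (tgt, memory), device):.3f} ms/iter")
```

Calling main() prints the two timings. The explicit synchronize before and after the timed loop matters, because MPS work is queued asynchronously and the loop could otherwise finish before the GPU does.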

It shows that running the decoder layer in eager mode is significantly faster than in JIT mode. Is this because there are some special optimizations for the decoder layer when running in eager mode?

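torch.compile can be slower on MPS because Inductor's mature GPU path generates Triton kernels, and there is no Triton backend for Metal. The MPS path in Inductor is newer and less optimized, so you can end up paying the compile-side overhead without the kernel fusion payoff. Note that torch._dynamo.explain(model)(inputs) only reports Dynamo-level numbers (graph count, graph break count, op count), not fused kernels. Use it to rule out graph breaks, and run with TORCH_LOGS=output_code to see what Inductor actually generated.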

Also, Apple Silicon’s unified memory means that any savings from avoiding CPU-GPU transfers, which can matter on CUDA, basically do not exist here.

Some practical advice:

  • Stick with eager mode on MPS for now — it is genuinely faster, not a compromise

  • Set PYTORCH_ENABLE_MPS_FALLBACK=1 so that ops not implemented on MPS fall back to CPU with a warning instead of raising an error; the warnings show you which ops are falling back

  • If you need compiled inference speed on Apple hardware, export via coremltools to CoreML. That is the Metal-native path and will likely outperform torch.compile on MPS
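
To check for CPU fallbacks, something like the following sketch may help. It assumes the documented PYTORCH_ENABLE_MPS_FALLBACK=1 variable, that it has to be set before torch is imported, and that fallback notices surface as ordinary Python warnings. The helper names collect_warnings and check_decoder_layer are hypothetical.

```python
import os

# Assumption: set this before importing torch so the fallback is registered.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import warnings

import torch
import torch.nn as nn


def collect_warnings(fn, *args):
    """Run fn(*args) and return (result, list of warning messages it raised)."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = fn(*args)
    return result, [str(w.message) for w in caught]


def check_decoder_layer():
    """Run one forward pass and return any warnings emitted along the way."""
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    layer = nn.TransformerDecoderLayer(
        d_model=512, nhead=8, dim_feedforward=2048,
        dropout=0.0, batch_first=True,  # batch_first is an assumption
    ).to(device).eval()
    tgt = torch.randn(8, 128, 512, device=device)
    memory = torch.randn(8, 256, 512, device=device)
    with torch.no_grad():
        _, messages = collect_warnings(layer, tgt, memory)
    return messages
```

Printing the result of check_decoder_layer() in a fresh process lists whatever warnings the forward pass produced. An empty list suggests no op fell back on that run, though some warnings may only be emitted once per process.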

MPS Inductor support is actively improving but not at CUDA parity as of PyTorch 2.11.