I was using nn.TransformerDecoderLayer for an experiment on dynamic-shape compilation on my Apple MPS device. However, simply running a single nn.TransformerDecoderLayer (one layer taking the target and memory tensors), once without torch.compile and once with it, gave me the following result.
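For reference, here is a minimal sketch of the benchmark I ran. The original script is reconstructed from the config printed below, so the helper names and timing loop are my own; it falls back to CPU when MPS is unavailable.

```python
import time
import torch
import torch.nn as nn

# Hypothetical reconstruction of the benchmark; falls back to CPU if MPS is absent.
device = "mps" if torch.backends.mps.is_available() else "cpu"

layer = nn.TransformerDecoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048,
    dropout=0.0, batch_first=True,
).to(device).eval()

# Shapes match the printout: tgt=(8, 128, 512), memory=(8, 256, 512)
tgt = torch.randn(8, 128, 512, device=device)
memory = torch.randn(8, 256, 512, device=device)

def bench(fn, warmup=20, repeats=100):
    """Return average ms/iter after a warmup phase."""
    with torch.no_grad():
        for _ in range(warmup):
            fn(tgt, memory)
        if device == "mps":
            torch.mps.synchronize()  # finish queued GPU work before timing
        start = time.perf_counter()
        for _ in range(repeats):
            fn(tgt, memory)
        if device == "mps":
            torch.mps.synchronize()
        return (time.perf_counter() - start) / repeats * 1e3

eager_ms = bench(layer)
compiled_ms = bench(torch.compile(layer))
print(f"eager: {eager_ms:.3f} ms/iter")
print(f"compile: {compiled_ms:.3f} ms/iter")
```

The warmup iterations are meant to exclude torch.compile's one-time compilation cost from the per-iteration numbers.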
PyTorch 2.10.0, device=mps, dtype=float32, compile_mode=default
shape: tgt=(8, 128, 512), memory=(8, 256, 512)
config: nhead=8, dim_feedforward=2048, dropout=0.0, repeats=100, warmup=20
eager: 4.950 ms/iter
compile: 30.413 ms/iter
This shows that running the decoder layer in eager mode is significantly faster than running it compiled. Is this because there are special optimizations for the decoder layer in eager mode?