Faster Attention with the MPS backend

If you’re running PyTorch models on Apple Silicon, I just open-sourced a custom attention operator that may be useful. It wraps Apple’s scaledDotProductAttention MPS Graph operation, which frequently outperforms PyTorch’s scaled_dot_product_attention on the MPS backend for sequences of 1024+ tokens.

Code: jhurt/attention-mps-torch on GitHub, a custom PyTorch operator for invoking Metal Performance Shaders Graph Scaled Dot Product Attention.
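
If you want a baseline to compare against on your own machine, here is a rough sketch that times PyTorch's built-in `scaled_dot_product_attention` at a few sequence lengths. It only measures the stock PyTorch path; it does not call the custom operator, whose API is documented in the repo. The helper names (`pick_device`, `sync`, `time_sdpa`) and the tensor shapes (1 batch, 8 heads, head dimension 64) are my own assumptions for illustration, and it falls back to CPU when MPS is unavailable so it still runs elsewhere.

```python
import time

import torch
import torch.nn.functional as F


def pick_device() -> torch.device:
    """Use MPS when available, otherwise fall back to CPU so the script still runs."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def sync(device: torch.device) -> None:
    """MPS work is queued asynchronously; wait for it before reading the clock."""
    if device.type == "mps":
        torch.mps.synchronize()


def time_sdpa(seq_len, device, batch=1, heads=8, head_dim=64, iters=10):
    """Return mean seconds per call of PyTorch's built-in SDPA (baseline only)."""
    # Shapes are illustrative assumptions, not taken from the original post.
    shape = (batch, heads, seq_len, head_dim)
    q, k, v = (torch.randn(shape, device=device) for _ in range(3))

    # Warm-up call so one-time setup cost is not included in the timing.
    F.scaled_dot_product_attention(q, k, v)
    sync(device)

    start = time.perf_counter()
    for _ in range(iters):
        out = F.scaled_dot_product_attention(q, k, v)
    sync(device)
    elapsed = time.perf_counter() - start

    assert out.shape == shape  # attention output has the same shape as q
    return elapsed / iters


if __name__ == "__main__":
    device = pick_device()
    for seq_len in (256, 1024, 2048):
        ms = time_sdpa(seq_len, device) * 1e3
        print(f"{device.type} seq_len={seq_len}: {ms:.2f} ms per call")
```

Running the same loop with the custom operator swapped in for `F.scaled_dot_product_attention` should show where the crossover happens on your hardware. Timings vary by chip, dtype, and PyTorch version, so treat any single run as indicative only.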