Is there a way to implement RoPE around `nn.MultiheadAttention` somehow?

I want to implement Rotary Position Embeddings in PyTorch; however, it seems they need to be applied to the outputs of the query and key linear projections, before scaled dot-product attention is computed (unlike sinusoidal positional encoding, which is applied to the word embeddings directly).
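For reference, here is a minimal sketch of what "applied to the outputs of the projections" means in practice: the rotation acts on the projected query and key vectors, and the dot product of a rotated query at position m with a rotated key at position n then depends only on the offset m - n. It assumes the "rotate-half" layout (pairing dimension i with i + d/2) used by many open-source implementations rather than the interleaved pairs of the paper; `rope_cache`, `rotate_half` and `apply_rope` are hypothetical helper names.

```python
import torch


def rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
    """Precompute cos/sin tables for rotary embeddings (rotate-half layout, assumed)."""
    # theta_i = base^(-2i/d) for i = 0 .. d/2 - 1
    inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)      # (seq_len, head_dim / 2)
    angles = torch.cat([angles, angles], dim=-1)   # (seq_len, head_dim)
    return angles.cos(), angles.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into halves (x1, x2) and return (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotates each pair (x_i, x_{i + d/2}) by the angle position * theta_i
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
d = 8
cos, sin = rope_cache(seq_len=16, head_dim=d)
q = torch.randn(d)  # stands in for a projected query, W_q x
k = torch.randn(d)  # stands in for a projected key, W_k x


def score(m: int, n: int) -> float:
    # Dot product of the query rotated to position m and the key rotated to position n
    return (apply_rope(q, cos[m], sin[m]) @ apply_rope(k, cos[n], sin[n])).item()


print(score(3, 1), score(10, 8))  # equal up to float error: both have m - n = 2
```

Rotating the embeddings before W_q and W_k would not give this property, because a general linear map does not commute with the rotation; that is why the rotation sits between the projections and the attention scores.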

I’m wondering whether there is still a way to implement Rotary Position Embeddings that works with nn.MultiheadAttention and doesn’t require reimplementing multi-head attention.

I want to do this because nn.MultiheadAttention will likely perform better than my own implementation, since it now uses FlashAttention and might have other optimizations I’m not aware of.

I was working on a PaLM model using lucidrains’ PyTorch implementation. It applies a rotary embedding between a LayerNorm and a Linear layer as part of a transformer block, so in this implementation it seems to come before the Linear and attention layers.

You can find that code here:

If I’m reading the code correctly, line 209 computes the outputs of those pre-attention linear layers, so the rotary embeddings are applied after the linear layers and before attention, as in other implementations.

The answer so far seems to be “no”, but it turns out I can just use torch.nn.functional.scaled_dot_product_attention to run the efficient SDPA implementations inside my own multi-head attention module, which makes this question mostly irrelevant. I’m not sure whether I lose any performance by not using nn.Transformer, though.
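A minimal sketch of that approach, assuming a plain self-attention layer with a fused QKV projection (the class name `RoPEMultiheadAttention` is hypothetical). nn.MultiheadAttention performs its input projections internally, so the rotation cannot be slotted in there; here queries and keys are rotated after the projection and then passed to F.scaled_dot_product_attention, which PyTorch may dispatch to a fused kernel (such as FlashAttention) depending on device, dtype and inputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) over the last dimension (rotate-half layout, assumed)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


class RoPEMultiheadAttention(nn.Module):
    """Hypothetical self-attention: QKV projection -> RoPE on q, k -> SDPA."""

    def __init__(self, embed_dim: int, num_heads: int, base: float = 10000.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim % 2 == 0, "RoPE rotates pairs of dimensions"
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        inv_freq = base ** (-torch.arange(0, self.head_dim, 2).float() / self.head_dim)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
        b, t, _ = x.shape
        # (b, t, 3 * E) -> (b, t, 3, h, d) -> three tensors of shape (b, h, t, d)
        q, k, v = self.qkv(x).view(b, t, 3, self.num_heads, self.head_dim).unbind(dim=2)
        q, k, v = (z.transpose(1, 2) for z in (q, k, v))

        # Rotate queries and keys AFTER the projection and BEFORE attention
        pos = torch.arange(t, device=x.device, dtype=self.inv_freq.dtype)
        angles = torch.outer(pos, self.inv_freq)
        angles = torch.cat([angles, angles], dim=-1)  # (t, d), broadcasts over b, h
        cos, sin = angles.cos().to(q.dtype), angles.sin().to(q.dtype)
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin

        out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        out = out.transpose(1, 2).reshape(b, t, -1)  # merge heads back
        return self.out_proj(out)


torch.manual_seed(0)
mha = RoPEMultiheadAttention(embed_dim=64, num_heads=4)
x = torch.randn(2, 10, 64)
y = mha(x, is_causal=True)
print(y.shape)  # torch.Size([2, 10, 64])
```

The cos/sin tables are recomputed on every call for brevity; caching them per sequence length would be cheaper for long inputs.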

RoPE-ing the q and k inputs to torch.nn.functional.scaled_dot_product_attention is certainly possible. However, I think that if one reads the original RoPE paper carefully, they note in eq. (19) that the denominator should not be rotated, to prevent accidental division by zero. I am not sure whether torch.nn.functional.scaled_dot_product_attention already handles this somehow. So if one wants to stick to the original paper, there is currently (as far as I understand) no way other than “inserting” the rotations by re-implementing SDPA with RoPE. Unfortunately, the layer then loses the enhancements of the optimized attention implementations unless one writes a custom kernel. I have set up a little demo of the described approach on my GitHub. Feedback welcome! :)
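To make the eq. (19) structure concrete, here is a sketch of non-causal linear attention with RoPE in which only the numerator uses rotated features. It assumes the feature map φ(x) = elu(x) + 1 (commonly used for linear attention) and the rotate-half layout; neither detail comes from this thread, and `rope` and `linear_attention_rope` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (..., t, d); rotates position p of the sequence by p * theta_i
    t, d = x.shape[-2], x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)
    angles = torch.outer(torch.arange(t, dtype=x.dtype), inv_freq)
    angles = torch.cat([angles, angles], dim=-1)
    return x * angles.cos() + rotate_half(x) * angles.sin()


def linear_attention_rope(q, k, v, eps: float = 1e-6):
    """Non-causal linear attention shaped like eq. (19): rotated features in the
    numerator, unrotated features in the denominator."""
    phi_q = F.elu(q) + 1  # assumed feature map, strictly positive entries
    phi_k = F.elu(k) + 1
    # Numerator: sum_n (R_m phi(q_m))^T (R_n phi(k_n)) v_n, computed as R phi(Q) @ (R phi(K))^T V
    rq, rk = rope(phi_q), rope(phi_k)
    kv = rk.transpose(-2, -1) @ v   # (..., d, d_v), cost linear in t
    num = rq @ kv                   # (..., t, d_v)
    # Denominator: sum_n phi(q_m)^T phi(k_n), NOT rotated, so it stays positive
    den = phi_q @ phi_k.sum(dim=-2, keepdim=True).transpose(-2, -1)  # (..., t, 1)
    return num / (den + eps)


torch.manual_seed(0)
q, k = torch.randn(2, 7, 8), torch.randn(2, 7, 8)
v = torch.randn(2, 7, 5)
out = linear_attention_rope(q, k, v)
```

With a rotated numerator and an unrotated denominator the attention weights no longer sum exactly to one; the denominator only acts as a positive normalizer.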

Hey @JannisZeller
I came across your comment here while implementing RoPE in PyTorch.

I just wanted to discuss eq. 19 in the paper: as I understand it, leaving q and k unrotated in the denominator applies only to linear self-attention, not to SDPA. I also went through some implementations by Hugging Face, and they seem to apply softmax as-is to attention scores computed from the rotated q and k, implying that the q and k in the denominator of the SDPA formula are rotated too.
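Written out, softmax attention with RoPE rotates each query and key once, and the same rotated scores appear in both the numerator and the softmax normalizer. Here $R^d_{\Theta,m}$ denotes the paper's rotation matrix for position $m$, and $\mathbf{q}_m$, $\mathbf{k}_n$, $\mathbf{v}_n$ are the projected vectors:

```latex
\mathbf{o}_m
  = \frac{\sum_{n} \exp\!\left(\tfrac{1}{\sqrt{d}}\,(R^d_{\Theta,m}\mathbf{q}_m)^\top (R^d_{\Theta,n}\mathbf{k}_n)\right)\mathbf{v}_n}
         {\sum_{n} \exp\!\left(\tfrac{1}{\sqrt{d}}\,(R^d_{\Theta,m}\mathbf{q}_m)^\top (R^d_{\Theta,n}\mathbf{k}_n)\right)},
\qquad
(R^d_{\Theta,m}\mathbf{q}_m)^\top (R^d_{\Theta,n}\mathbf{k}_n) = \mathbf{q}_m^\top R^d_{\Theta,n-m}\,\mathbf{k}_n
```

The denominator is a sum of exponentials, hence strictly positive whether or not q and k are rotated, so the zero-division concern behind eq. (19) does not arise for softmax attention. Rotating q and k before calling torch.nn.functional.scaled_dot_product_attention (whose default scale is 1/sqrt(d)) yields exactly this expression.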

Hello @srishti-git1110, thanks for pointing that out! I did not notice the distinction between linear attention and SDPA. Still, I guess / hope that my suggested implementation of eq. 19 is correct as it stands.