Is there a way to implement RoPE around `nn.MultiheadAttention` somehow?

I want to implement Rotary Position Embeddings (RoPE) in PyTorch. However, it seems they need to be applied to the outputs of the query/key projection layers, right before scaled dot-product attention is computed (unlike sinusoidal positional encoding, which is added to the word embeddings directly).
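For context, here is a minimal sketch of that step (helper names like `build_rope_cache` and `apply_rope` are my own, not from any particular library): the rotation has to sit between the Q/K projections and the attention computation, which is exactly the part `nn.MultiheadAttention` doesn't expose.

```python
import torch

def build_rope_cache(seq_len, head_dim, base=10000.0, device=None):
    # Standard RoPE frequencies; each angle is repeated across the two halves of head_dim.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)            # (seq, head_dim / 2)
    angles = torch.cat((freqs, freqs), dim=-1)  # (seq, head_dim)
    return angles.cos(), angles.sin()

def rotate_half(x):
    # Swap the two halves of the last dim and negate the second one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # q, k: (batch, heads, seq, head_dim); cos/sin broadcast over batch and heads.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```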

I’m wondering if there is still a way to implement Rotary Position Embeddings that works with nn.MultiheadAttention and doesn’t require reimplementing multi-head attention.

I want to do this because nn.MultiheadAttention will likely perform better than my own implementation, since it now uses FlashAttention and may include other optimizations I’m not aware of.

I was working on a PaLM model using lucidrains’ PyTorch implementation. It uses a rotary embedding between a LayerNorm and a Linear layer as part of the transformer block, so in this implementation the rotary embedding seems to come before the Linear and attention layers.

You can find that code here:

If I’m reading the code correctly, line 209 computes the outputs of those pre-attention linear layers, so the rotary embeddings are applied after the linear projections and before attention, like in other implementations.

The answer so far seems to be “no”, but as it turns out I can just use torch.nn.functional.scaled_dot_product_attention to get the efficient SDPA kernels inside my custom multi-head attention implementation, which makes this question moot. I’m not sure whether I’m losing any performance by not using nn.Transformer, though.
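For anyone finding this later, here is a rough sketch of what I mean. The module name is hypothetical, it reuses the `apply_rope` helper from the earlier snippet, and the cos/sin cache would come from something like `build_rope_cache(seq_len, head_dim)`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RoPEMultiheadAttention(nn.Module):
    """Sketch of a custom MHA: apply RoPE to q/k, then call the fused SDPA kernel."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x, cos, sin, is_causal=True):
        b, s, _ = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq, head_dim).
        q, k, v = [t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v)]
        # RoPE goes here: after the projections, before attention.
        q, k = apply_rope(q, k, cos, sin)
        # Dispatches to FlashAttention / memory-efficient kernels when the inputs allow it.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        out = out.transpose(1, 2).reshape(b, s, -1)
        return self.out_proj(out)
```

Since the cos/sin cache only depends on sequence length and head dimension, it can be built once and shared across layers (e.g. registered as buffers) rather than recomputed every forward pass.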