GQA support in scaled_dot_product_attention

In the docs, I see that scaled_dot_product_attention has GQA support, allowing number_of_heads_key_value < number_of_heads_query. But when I call it this way (e.g., number_of_heads_key_value == 1, number_of_heads_query > 1), I get an exception. I am on torch==2.5.1.
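For reference, this is roughly what I am calling (a minimal sketch with made-up shapes; I am assuming the enable_gqa flag that the 2.5 docs describe):

```python
import torch
import torch.nn.functional as F

B, T, D = 2, 16, 64           # batch, sequence length, head dim (made-up sizes)
n_q_heads, n_kv_heads = 8, 1  # GQA/MQA: fewer key/value heads than query heads

q = torch.randn(B, n_q_heads, T, D)
k = torch.randn(B, n_kv_heads, T, D)
v = torch.randn(B, n_kv_heads, T, D)

# According to the docs this should be allowed with enable_gqa=True,
# but a call like this raises an exception for me on 2.5.1.
out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
```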

Does anybody know when this will be supported? It matters not only for GQA, but also for multi-head latent attention in the DeepSeek models. For inference I do not need it and can write the attention myself, since matmul broadcasts over the batch/head dimensions (see the sketch below), but when I want to fine-tune such models I really need FlashAttention etc., and with this feature missing I am stuck.
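This is the kind of workaround I mean, just a naive sketch for the number_of_heads_key_value == 1 case (general GQA would additionally need a repeat_interleave or reshape of K/V), without any of the memory savings of a fused FlashAttention kernel:

```python
import math
import torch

def naive_broadcast_attention(q, k, v):
    # q: [B, n_q_heads, T, D], k/v: [B, 1, T, D].
    # matmul broadcasts over the head dimension, so the single K/V head is
    # shared by all query heads without materializing per-head copies.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # [B, n_q_heads, T, T]
    weights = scores.softmax(dim=-1)
    return weights @ v                                          # [B, n_q_heads, T, D]

B, T, D = 2, 16, 64
q = torch.randn(B, 8, T, D)  # 8 query heads
k = torch.randn(B, 1, T, D)  # single key/value head
v = torch.randn(B, 1, T, D)
out = naive_broadcast_attention(q, k, v)  # [B, 8, T, D]
```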

To whoever is behind this implementation: the case I am describing here is very relevant for DeepSeek (multi-head latent attention) and is an obvious way to save memory and possibly GPU compute. For example, if I run DeepSeek with 128 query heads, the key/value inputs can be smaller by a factor of 128.