scaled_dot_product_attention: is_causal=True with L < S not useful for long context inference

I am implementing long context inference and noticed an odd deficiency of scaled_dot_product_attention.

In long context inference, a prompt is processed in chunks, so the query is shorter than key and value (in the docs: L < S). I can pass attn_mask and set is_causal=False, but I noticed that in this case often only the naive C++ implementation is supported. I’d therefore like to use the is_causal=True case with L < S.

But the way this is implemented, it is not really useful. If i indexes the query and j the key, the masking is such that the score inner(q_i, k_j) is set to -inf whenever i < j. This means that most entries are -inf. Say L = 3, S = 5; the mask matrix is:

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf]])
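
For concreteness, here is a small sketch of how I understand this mask to be constructed for L < S (adapted from the reference implementation shown in the SDPA docs; the exact tril call and shapes are my reconstruction):

import torch

L, S = 3, 5  # query (chunk) length, key/value length

# is_causal=True anchors the causal mask at the top-left corner of the
# (L, S) score matrix: entry (i, j) is masked whenever j > i.
causal_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias = torch.zeros(L, S).masked_fill(~causal_mask, float("-inf"))
print(attn_bias)
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf]])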

That makes no sense: why would I ever want to mask out most of the score entries? What is needed for inference is that the mask is

tensor([[0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

This would correspond to a new chunk of size 3, whereas the KV cache has length 5 (the final 3 slots are for the new tokens).
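
A minimal sketch of the workaround (the batch/head/dim shapes are made up for the example; the key point is diagonal=S - L, which anchors the triangle at the bottom-right corner, and that the mask then has to be passed explicitly with is_causal=False):

import torch
import torch.nn.functional as F

L, S = 3, 5                 # chunk length, KV cache length
B, H, D = 1, 8, 64          # batch, heads, head dim (made up for the example)
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Token i of the chunk sits at absolute position S - L + i, so it may attend
# to keys j <= S - L + i: a triangle anchored at the bottom-right corner.
causal_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S - L)
attn_bias = torch.zeros(L, S).masked_fill(~causal_mask, float("-inf"))
print(attn_bias)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])

# Works, but an explicit attn_mask rules out the fast fused backends:
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, is_causal=False)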

I do wonder what the point of this definition of the mask is when is_causal=True, given that the most efficient SDPA implementations are only available in this special case (i.e. attn_mask=None).

Sure, I can just use my own implementation in this case, or pass in attn_mask (in which case the efficient SDPA implementations cannot be used), but why should I have to?

Note that for long context inference, this IS important. I might have a KV cache of size (say) S = 2 ** 16 and a block size for processing the prompt of L = 2 ** 10, in which case I’d appreciate a fast SDPA implementation.
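
To make the use case concrete, here is a rough sketch of such a chunked prefill loop (everything here, the cache layout, the sizes scaled down from the numbers above, random tensors standing in for real projections, is my own illustration):

import torch
import torch.nn.functional as F

B, H, D = 1, 8, 64
chunk_len = 2 ** 8                   # L: prompt block size (scaled down)
prompt_len = 2 ** 10                 # total prompt length (scaled down)

k_cache = torch.empty(B, H, 0, D)    # KV cache, grows to length S
v_cache = torch.empty(B, H, 0, D)

for start in range(0, prompt_len, chunk_len):
    # Stand-ins for the projected queries/keys/values of the new chunk.
    q = torch.randn(B, H, chunk_len, D)
    k = torch.randn(B, H, chunk_len, D)
    v = torch.randn(B, H, chunk_len, D)

    k_cache = torch.cat([k_cache, k], dim=2)
    v_cache = torch.cat([v_cache, v], dim=2)
    S = k_cache.shape[2]             # current cache length, chunk_len <= S

    # What I would like is_causal=True to mean for L < S; instead the
    # (L, S) mask has to be built and passed explicitly.
    mask = torch.ones(chunk_len, S, dtype=torch.bool).tril(diagonal=S - chunk_len)
    bias = torch.zeros(chunk_len, S).masked_fill(~mask, float("-inf"))
    out = F.scaled_dot_product_attention(q, k_cache, v_cache,
                                         attn_mask=bias, is_causal=False)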

As it stands, the case L < S with is_causal=True does not seem useful for anything I’d want to do with SDPA. But if the default mask in this case were defined just a bit differently (see above: more 0s, fewer -inf entries), it would be really useful for long context inference.