I am implementing long-context inference and noticed an odd deficiency of `scaled_dot_product_attention`.

In long-context inference, a prompt is processed in chunks, so that `query` is shorter than `key` and `value` (in the docs: `L < S`). I can pass `attn_mask` and set `is_causal=False`, but in that case often only the naive C++ implementation is supported. I'd therefore like to use the `is_causal=True` case with `L < S`.
But the way this is implemented, it is not really useful. If `i` indexes `query` and `j` indexes `key`, the masking is such that what is computed is `inner(q_i, k_j) + (-inf) * ind(i < j)`. This means that most entries are `-inf`. Say `L = 3`, `S = 5`. The mask matrix is:
tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf]])
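For concreteness, here is a small check (plain PyTorch on CPU, arbitrary shapes) that, as far as I can tell, `is_causal=True` with `L < S` really is equivalent to passing this top-left aligned mask:

```python
import torch
import torch.nn.functional as F

L, S, H, D = 3, 5, 2, 8
q = torch.randn(1, H, L, D)
k = torch.randn(1, H, S, D)
v = torch.randn(1, H, S, D)

# Top-left aligned mask: query position i may attend to key positions j <= i.
i = torch.arange(L).unsqueeze(-1)   # query index
j = torch.arange(S)                 # key index
top_left = torch.zeros(L, S).masked_fill(j > i, float("-inf"))

out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=top_left)
print((out_causal - out_masked).abs().max())  # ~0, the two agree
```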
That makes no sense: why would I ever want to mask out most score entries? What is needed for inference is for the mask to be
tensor([[0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])
This would correspond to a new chunk of size 3, whereas the KV cache has length 5 (the final 3 slots are for the new tokens).
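Such a bottom-right aligned mask can be built by hand and passed via `attn_mask`; a minimal sketch (shapes for illustration only):

```python
import torch
import torch.nn.functional as F

L, S = 3, 5  # chunk of 3 new queries, KV cache of length 5 (last 3 slots are the new tokens)

# Bottom-right aligned causal mask: query i sits at absolute position (S - L) + i,
# so it may attend to key positions j <= (S - L) + i.
i = torch.arange(L).unsqueeze(-1)
j = torch.arange(S)
bottom_right = torch.zeros(L, S).masked_fill(j > (S - L) + i, float("-inf"))
print(bottom_right)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])

# Passing it explicitly works, but rules out the fast attn_mask=None kernels:
q = torch.randn(1, 4, L, 32)
k = torch.randn(1, 4, S, 32)
v = torch.randn(1, 4, S, 32)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bottom_right, is_causal=False)
```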
I do wonder what the point of this definition of the mask for `is_causal=True` is, given that only in this special case (i.e. `attn_mask=None`) are the most efficient SDPA implementations available. Sure, I can just use my own implementation in this case, or pass in `attn_mask` (in which case the efficient SDPA implementations cannot be used), but why?
Note that for long-context inference, this IS important. I might have a KV cache of size (say) `S = 2**16` and a block size for processing the prompt of `L = 2**10`, in which case I'd appreciate a fast SDPA implementation.
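To make that use case concrete, here is a toy version of the chunked-prefill loop I have in mind (sizes scaled down; `chunk_mask` is my own helper, not a PyTorch API):

```python
import torch
import torch.nn.functional as F

def chunk_mask(L, S, device=None):
    # Bottom-right aligned causal mask for L new queries against a KV cache of
    # length S whose last L slots hold the new tokens.
    i = torch.arange(L, device=device).unsqueeze(-1)
    j = torch.arange(S, device=device)
    return torch.zeros(L, S, device=device).masked_fill(j > (S - L) + i, float("-inf"))

# Toy sizes; in the setting above prompt_len would be ~2**16 and block ~2**10.
B, H, D, prompt_len, block = 1, 4, 32, 16, 4
q_full = torch.randn(B, H, prompt_len, D)
k_full = torch.randn(B, H, prompt_len, D)
v_full = torch.randn(B, H, prompt_len, D)

for start in range(0, prompt_len, block):
    end = start + block
    q = q_full[:, :, start:end]                     # L = block new queries
    k, v = k_full[:, :, :end], v_full[:, :, :end]   # KV cache of length S = end
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=chunk_mask(block, end), is_causal=False
    )
```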
As it stands, the case `L < S` with `is_causal=True` seems not useful to me for anything I'd want to do with SDPA. But if the default mask in this case were defined just a bit differently (see above: more zeros, fewer `-inf` entries), it would be really useful for long-context inference.