Hello everyone,
I was experimenting with Flex Attention when I ran into a problem that makes it unusable for my use case:
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx
I need to apply a custom attention mask (not the causal one above) that is computed from the specific sequence of tokens itself (that is how my custom architecture works). So in a batch of 4 sequences, I might have 4 different masks, each depending on its own sequence. I compute the masks outside this function and store them in a tensor, and because the mask_mod function can read any tensor on the device, I can look up a custom mask for each sequence.
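For reference, here is a minimal sketch of what I am trying to do. It is simplified: the real masks are computed from the token content of each sequence (a random placeholder stands in for that here), and it assumes that "b" indexes the sequence within the batch, which is exactly the assumption that seems to break:

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 4, 8, 128, 64  # example sizes

# Precomputed per-sequence masks: one (S, S) boolean mask per sequence.
# In my setup these come from the tokens of each sequence; here a random
# placeholder is used, with the diagonal forced on so no row is fully masked.
per_seq_mask = torch.rand(B, S, S, device="cuda") > 0.5
per_seq_mask |= torch.eye(S, device="cuda", dtype=torch.bool)

def content_mask(b, h, q_idx, kv_idx):
    # Only correct if b really is the index of the sequence within the batch.
    return per_seq_mask[b, q_idx, kv_idx]

block_mask = create_block_mask(content_mask, B=B, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)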
But…
I initially assumed that “b” is the batch index, but it’s not.
Without knowing which sequence in the batch I am processing, I cannot compute a separate mask for each sequence. I am left with applying the same mask to the entire batch, so I cannot use Flex Attention to implement my architecture at scale, that is, with vLLM and large batches.
I got quite far in patching the vLLM code and using the newly added Flex Attention backend, but I ran into the same issue there. vLLM works phenomenally well on batches of sequences, and because of this limitation of Flex Attention (not being able to set a mask per sequence), I am restricted to running maybe one sequence at a time in vLLM.
Could someone confirm that this limitation indeed exists, and if so, are there plans to fix it?
If I am correct, this makes Flex Attention not very useful for a wide range of architectures. We want custom attention masks, not just “blind” attention patterns like checkerboards or sliding windows; those patterns seldom make much difference. What is far more interesting are patterns based on the content of the sequences themselves, which is what I am trying to achieve.
Thank you!