Flex_attention inference with variable length inputs

I’m using flex_attention (or scaled_dot_product_attention).

At inference time the length of the sequence can vary (but the mask is otherwise not data dependent).

When not using flex_attention I can just trim the precomputed dense mask to the current sequence length (inference is batch size 1):

```python
mask = dense_mask[:, :, (S - L) : S, :S]
```

(and if I enable scalar capture this even plays nicely with torch.compile).
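For reference, here is a minimal sketch of what I'm doing today on the SDPA side. The shapes, `S_MAX`, and the helper names (`dense_mask`, `sdpa_step`) are just illustrative; I'm assuming a KV-cache setup where `L` is the number of new query tokens and `S` is the total sequence length so far:

```python
import torch
import torch.nn.functional as F

S_MAX = 2048   # longest sequence the precomputed mask covers (illustrative)
H, D = 8, 64   # heads / head dim, arbitrary for the sketch

# Precompute the full, non-data-dependent mask once (causal here as a stand-in).
dense_mask = torch.ones(1, 1, S_MAX, S_MAX, dtype=torch.bool).tril()

def sdpa_step(q, k, v, S, L):
    # q: [1, H, L, D] new queries; k, v: [1, H, S, D] cached keys/values.
    # Trim the precomputed mask to the current lengths instead of rebuilding it.
    mask = dense_mask[:, :, (S - L) : S, :S]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q = torch.randn(1, H, 1, D)     # decoding one new token
kv = torch.randn(1, H, 17, D)   # 17 tokens of cached keys/values
out = sdpa_step(q, kv, kv, S=17, L=1)
print(out.shape)                # torch.Size([1, 8, 1, 64])
```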

I understand that recomputing the block attention mask is somewhat expensive. Is there a recommended pathway to trim the block attention mask at inference time that avoids the need to recompute it?
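The best workaround I can think of on the flex_attention side is sketched below: memoize one BlockMask per (L, S) pair with create_block_mask, offsetting q_idx so the result matches the dense-mask trim above. This only amortizes repeated lengths; every new length still pays the full create_block_mask cost, which is exactly what I'd like to avoid. (`causal_mask`, `get_block_mask`, and `flex_step` are placeholder names, and the causal pattern is just a stand-in for my actual static mask.)

```python
import torch
from functools import lru_cache
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # Stand-in mask_mod; the real one is whatever static pattern the model uses.
    return q_idx >= kv_idx

# Memoize one BlockMask per (L, S) so create_block_mask only runs once per
# distinct length pair -- a new length still pays the full rebuild cost.
@lru_cache(maxsize=None)
def get_block_mask(L, S):
    def offset_mask(b, h, q_idx, kv_idx):
        # Shift q_idx so the L new queries sit at absolute positions S-L .. S-1,
        # matching the dense_mask[:, :, (S - L):S, :S] trim.
        return causal_mask(b, h, q_idx + (S - L), kv_idx)
    return create_block_mask(offset_mask, B=None, H=None, Q_LEN=L, KV_LEN=S, device="cuda")

def flex_step(q, k, v, S, L):
    return flex_attention(q, k, v, block_mask=get_block_mask(L, S))

# Illustrative shapes (assumes a CUDA device).
H, D = 8, 64
L, S = 16, 128
q = torch.randn(1, H, L, D, device="cuda")
k = torch.randn(1, H, S, D, device="cuda")
v = torch.randn(1, H, S, D, device="cuda")
out = flex_step(q, k, v, S=S, L=L)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```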

This thread might be related! Dynamic mask block sizes during inference · Issue #109 · pytorch-labs/attention-gym · GitHub
