I’m using flex_attention (or scaled_dot_product_attention).
At inference time the sequence length can vary (but the mask is otherwise not data-dependent).
When not using flex_attention I can just trim the mask to the sequence length (inference is batch size 1):

```python
mask = dense_mask[:, :, (S - L) : S, :S]
```
(and if I enable capture scalars, this even plays nicely with torch.compile).
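Concretely, here's a minimal sketch of what I do on the SDPA path (the name `sdpa_with_trimmed_mask`, the fixed `S_MAX`, and the plain causal mask are just for illustration; my real mask is static but not simply causal):

```python
import torch
import torch.nn.functional as F

# Precompute one dense mask up to a maximum length (plain causal here, just for illustration).
S_MAX = 2048
dense_mask = torch.tril(torch.ones(S_MAX, S_MAX, dtype=torch.bool))[None, None]

def sdpa_with_trimmed_mask(q, k, v):
    # q: [1, H, L, D], k/v: [1, H, S, D]; inference is batch size 1.
    # (dense_mask is assumed to live on the same device as q/k/v.)
    L = q.shape[-2]   # number of new query positions this step
    S = k.shape[-2]   # current total sequence length
    # Trim the precomputed mask to the live lengths:
    mask = dense_mask[:, :, (S - L) : S, :S]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```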
I understand that recomputing the block attention mask is somewhat expensive. Is there a recommended way to trim the block mask at inference time that avoids the need to recompute it?
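For reference, this is roughly the per-step recomputation I'm hoping to avoid (a sketch only; the causal `shifted_causal` mask_mod and the `attend` wrapper are illustrative, not my actual mask):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def attend(q, k, v):
    # q: [1, H, L, D], k/v: [1, H, S, D]; batch size 1 at inference.
    L, S = q.shape[-2], k.shape[-2]

    def shifted_causal(b, h, q_idx, kv_idx):
        # Queries are the last L positions of the S-token sequence.
        return (q_idx + S - L) >= kv_idx

    # Rebuilding the BlockMask every time L or S changes is the expensive part.
    block_mask = create_block_mask(shifted_causal, B=None, H=None,
                                   Q_LEN=L, KV_LEN=S, device=q.device)
    return flex_attention(q, k, v, block_mask=block_mask)
```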