This requirement was stated in the original blog post: https://pytorch.org/blog/flexattention/
“FlexAttention requires that all sequence lengths be a multiple of 128 – this will be addressed soon.”
It is not mentioned in the docs for either torch version 2.5 or 2.7, though: torch.nn.attention.flex_attention — PyTorch 2.5 documentation
What’s the current status? I did not see it mentioned in the recent torch release notes.
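For context, here is a minimal sketch of the check I had in mind (assuming a CUDA GPU and torch >= 2.5; the shapes and dtype are just illustrative), using a sequence length that is deliberately not a multiple of 128:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# S = 1000 is intentionally not a multiple of 128
B, H, S, D = 2, 4, 1000, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

# Compile as recommended for FlexAttention; if the multiple-of-128
# restriction still applies, this call should fail for S = 1000.
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v)
print(out.shape)
```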