Hello,
I'm really interested in FlexAttention but don't fully understand it yet. I've had great luck speedrunning GPT-2 and ESM-2 with it and want to keep using it during pretraining. As I understand it, the recommended way to use FlexAttention is essentially with "batch size" 1: flatten the entire input into one long packed sequence and apply the appropriate block mask (roughly like the sketch below).
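Just to make sure I have the setup right, here is a minimal sketch of what I mean by packing with a document-level causal block mask. The document lengths, shapes, and device are placeholders, and in practice I'd wrap `flex_attention` in `torch.compile` for speed:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Hypothetical packed "batch": three documents flattened into one sequence.
lengths = torch.tensor([384, 512, 128])
S = int(lengths.sum())
# Document id for every token position in the packed sequence.
docs = torch.repeat_interleave(torch.arange(len(lengths)), lengths).to("cuda")

def doc_causal(b, h, q_idx, kv_idx):
    # Attend causally, and only to tokens from the same document.
    return (docs[q_idx] == docs[kv_idx]) & (q_idx >= kv_idx)

# B=None, H=None broadcasts the mask over batch and heads.
block_mask = create_block_mask(doc_causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

B, H, D = 1, 8, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)  # usually torch.compile(flex_attention)
```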
However, F.scaled_dot_product_attention drops into typical pipelines with regular batched (padded) inputs much more easily, which makes it an attractive choice for inference since it automatically dispatches to FlashAttention and the other fused backends. What I'm wondering is whether the outputs of FlexAttention and F.sdpa are actually equivalent, or at least very close, and if not, what corrections would be needed so that I can train with FlexAttention and run inference with sdpa? Something like the comparison below is what I have in mind.
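For context, this is the kind of check I'm imagining (again just a sketch with made-up shapes; I'd expect the two outputs to agree up to floating-point rounding if they really compute the same thing, since the backends differ):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

torch.manual_seed(0)
B, H, S, D = 1, 8, 256, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# Same causal masking expressed two ways.
causal_mask = create_block_mask(lambda b, h, q_idx, kv_idx: q_idx >= kv_idx,
                                B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
out_flex = flex_attention(q, k, v, block_mask=causal_mask)
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Small fp16-level differences seem expected; anything larger would worry me.
print((out_flex - out_sdpa).abs().max())
```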
Thanks,
Logan