Flex Attention Extremely Slow

I tried to use flex attention in the Hugging Face transformers library, only to find it very slow. Compared to the SDPA implementation, flex attention is about 4-5 times slower, although it does save CUDA memory.

Here is the example code: Example benchmark for flex attention. · GitHub
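Since the gist isn't inlined here, below is a minimal sketch of the kind of comparison being made, not the gist itself; the tensor shapes, dtype, and timing loop are assumptions, and the `torch.compile` wrapping reflects the usual guidance that eager `flex_attention` is much slower than the compiled kernel.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical shapes/dtype; the linked gist may use different values.
B, H, S, D = 4, 16, 2048, 64
device, dtype = "cuda", torch.float16

q = torch.randn(B, H, S, D, device=device, dtype=dtype)
k = torch.randn(B, H, S, D, device=device, dtype=dtype)
v = torch.randn(B, H, S, D, device=device, dtype=dtype)

def bench(fn, iters=20):
    # Warm up, then time with CUDA events.
    for _ in range(3):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

sdpa_ms = bench(lambda: F.scaled_dot_product_attention(q, k, v))

# flex_attention is normally wrapped in torch.compile; calling it eagerly
# is expected to be considerably slower.
compiled_flex = torch.compile(flex_attention)
flex_ms = bench(lambda: compiled_flex(q, k, v))

print(f"sdpa: {sdpa_ms:.3f} ms/iter, flex_attention: {flex_ms:.3f} ms/iter")
```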

My bad. See Flex Attention Extremely Slow · Issue #136261 · pytorch/pytorch · GitHub.