I tried using flex attention in Hugging Face Transformers, only to find it very slow. Compared to the SDPA implementation, flex attention is about 4-5x slower, although it does reduce CUDA memory usage.
Here is the example code: Example benchmark for flex attention. · GitHub
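For reference, here is a minimal sketch of the kind of benchmark I'm running (the linked gist has the full code; the model name, sequence length, and iteration count below are just placeholders):

```python
# Minimal sketch of the benchmark: compare forward-pass latency and peak CUDA memory
# for the "sdpa" vs. "flex_attention" backends. Model name and sizes are assumptions.
import time
import torch
from transformers import AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-3.2-1B"  # placeholder; any model with flex attention support
DEVICE = "cuda"

def benchmark(attn_implementation: str, n_iters: int = 10, seq_len: int = 2048):
    """Return (seconds per forward pass, peak CUDA memory in GiB) for one backend."""
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation=attn_implementation,
        torch_dtype=torch.bfloat16,
    ).to(DEVICE).eval()

    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device=DEVICE)

    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        model(input_ids)  # warmup
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            model(input_ids)
        torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / n_iters

    peak_mem = torch.cuda.max_memory_allocated() / 1024**3
    del model
    torch.cuda.empty_cache()
    return elapsed, peak_mem

for impl in ("sdpa", "flex_attention"):
    t, mem = benchmark(impl)
    print(f"{impl}: {t * 1000:.1f} ms / forward, peak memory {mem:.2f} GiB")
```

With this setup, flex attention comes out roughly 4-5x slower per forward pass than SDPA, though its peak memory is lower. Is this expected, or is there something I need to configure differently?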