Wondering how to use flex attention

Hi there, we may have run into an issue using flex attention from the nightly build at https://download.pytorch.org/whl/nightly/cu118.
We created the sliding window with
generate_sliding_window(window_size=self.window_size)
and computed the attention output with:
y = torch.nn.attention.flex_attention.flex_attention(q, k, v, block_mask=block_mask)
However, when we measure overall GPU memory use and compare it with a manual implementation of the sliding-window mask, flex attention doesn't show an improvement in running speed.
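When comparing running speed, two measurement pitfalls can hide a real difference: the first calls to a `torch.compile`'d function include compilation time, and CUDA kernels launch asynchronously, so the clock has to wait for the GPU. A minimal, stdlib-only timing helper (hypothetical name `benchmark`) that handles both, assuming the caller passes `torch.cuda.synchronize` as `sync` on GPU:

```python
import time
from statistics import median


def benchmark(fn, *, warmup=3, iters=10, sync=lambda: None):
    """Return the median wall-clock seconds per call of fn(), excluding warmup.

    Pass sync=torch.cuda.synchronize when fn launches CUDA kernels, otherwise
    only the (asynchronous) launch time is measured.
    """
    for _ in range(warmup):
        fn()  # absorbs one-time costs such as torch.compile compilation
    sync()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()  # wait for queued GPU work before stopping the clock
        times.append(time.perf_counter() - start)
    return median(times)
```

On GPU this might be called as `benchmark(lambda: attn(q, k, v, block_mask=block_mask), sync=torch.cuda.synchronize)`. For memory, `torch.cuda.reset_peak_memory_stats()` before the call and `torch.cuda.max_memory_allocated()` after it give the peak allocation, which is more informative than overall GPU usage reported by the driver.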


We're wondering how to make better use of flex attention, since it looks really useful for exploiting sparsity.
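One factor in how much sparsity can be exploited: a `BlockMask` skips work only at the granularity of whole tiles (`BLOCK_SIZE` x `BLOCK_SIZE`, 128 by default in `create_block_mask`). Tiles entirely outside the window are skipped, partially covered tiles are computed with per-element masking, and fully covered tiles are computed normally. The stdlib-only sketch below (hypothetical helper `block_sparsity`, assuming a causal sliding window) counts each kind, which gives a rough upper bound on the compute saved for a given sequence length and window size.

```python
def block_sparsity(seq_len, window_size, block_size=128):
    """Classify (q_block, kv_block) tiles of a causal sliding-window mask.

    A (q, k) pair is visible when 0 <= q - k < window_size.
    Returns (skipped, partial, full) tile counts.
    """
    n_blocks = -(-seq_len // block_size)  # ceiling division
    skipped = partial = full = 0
    for qb in range(n_blocks):
        q_lo = qb * block_size
        q_hi = min((qb + 1) * block_size, seq_len) - 1
        for kb in range(n_blocks):
            k_lo = kb * block_size
            k_hi = min((kb + 1) * block_size, seq_len) - 1
            # q - k takes every integer value in [min_diff, max_diff] in this tile
            min_diff = q_lo - k_hi
            max_diff = q_hi - k_lo
            if max_diff < 0 or min_diff >= window_size:
                skipped += 1  # nothing visible: the kernel can skip the tile
            elif min_diff >= 0 and max_diff < window_size:
                full += 1  # everything visible: no masking needed
            else:
                partial += 1  # mixed: computed with element-wise masking
    return skipped, partial, full


skipped, partial, full = block_sparsity(seq_len=1024, window_size=256)
print(skipped, partial, full, f"{(partial + full) / (skipped + partial + full):.0%} of tiles computed")
```

With `seq_len=1024` and `window_size=256`, 21 of 64 tiles are still computed, compared with 36 for a plain causal mask. The saving grows as the sequence gets long relative to the window, so short sequences should not be expected to show a large speedup over a dense masked implementation.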

Did you ever resolve this?