Flex attention benchmarking

I’m seeing large variance when running flex attention: some runs take far too long. This is beyond the first one or two iterations, where torch.compile needs time to compile.

import time

import torch
import xformers.ops as xops
from torch.nn.attention.flex_attention import flex_attention
from torch.utils.benchmark import Timer

flex_attention = torch.compile(flex_attention, dynamic=True)

def run_timer_timeit(fn, _len=10):
    t = Timer(stmt="fn()", globals={"fn": fn})
    print(t.timeit(_len)) 

def run_full_flex_attn():
    flex_attention(q, k, v).sum().backward()

def run_xformers():
    xops.memory_efficient_attention(q, k, v).sum().backward()

device="0"
S=20000
q = torch.randn(1, 32, S, 128, requires_grad=True, device='cuda:'+device)
k = torch.randn(1, 32, S, 128, requires_grad=True, device='cuda:'+device)
v = torch.randn(1, 32, S, 128, requires_grad=True, device='cuda:'+device)

run_timer_timeit(run_full_flex_attn, 100)  # 4.66 s
run_timer_timeit(run_xformers, 100)        # 0.007 s

The result from flex attention seems extremely slow. I wondered whether it comes from flex attention compiling, so I ran the same function multiple times, in case the first run included the compilation. I think flex attention is decently fast in most cases, but some runs might be getting stuck. When I ran it manually:

start = time.time()
run_full_flex_attn()
end = time.time()
print("Full Flex Attention time: ", end-start)

Then I see 0.002 s.
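
One caveat with this manual check (my assumption, since CUDA kernels launch asynchronously): time.time() without a synchronize may mostly capture the launch overhead rather than the actual forward and backward kernels. A synchronized version of the same measurement, reusing the setup above, would look roughly like this:

torch.cuda.synchronize()   # make sure prior GPU work has finished
start = time.time()
run_full_flex_attn()
torch.cuda.synchronize()   # wait for the forward and backward kernels to finish
end = time.time()
print("Full Flex Attention time (synchronized): ", end - start)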

So I’m not sure why the flex attention run sometimes gets stuck, when the compilation should have happened only once during the initial runs.
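
For reference, this is roughly how I would collect per-iteration timings to see whether the slowdown comes from a few stuck iterations or from every run (a sketch that reuses the q, k, v and functions defined above):

def time_per_iteration(fn, iters=20):
    for i in range(iters):
        torch.cuda.synchronize()          # wait for any previously queued work
        start = time.perf_counter()
        fn()
        torch.cuda.synchronize()          # ensure the kernels actually ran
        print(f"iter {i}: {time.perf_counter() - start:.4f} s")

time_per_iteration(run_full_flex_attn)
time_per_iteration(run_xformers)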