With torch.compile(mode="max-autotune")

I have tested torch.compile(mode="max-autotune") and I am amazed! It's awesome.

I have found, though, that it doesn't work with Flash Attention, and I have to wrap certain parts of the code in functions to get torch.compile to apply to them. It would be awesome if we could do something like this:

with torch.compile(mode="..."):
    # code that compiles

# code that doesn't compile

with torch.compile(mode="different mode"):
    # code that needs a different mode

Then we could hyper-optimize without writing a new function for each group!
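For reference, here is roughly what I mean by wrapping things in functions today. This is just a minimal sketch: the function names, tensor shapes, and the second mode are made up for illustration, and it assumes a CUDA GPU with half-precision inputs.

```python
import torch
import torch.nn.functional as F

# Each region that needs its own compile settings has to live in its own function.
def attention_part(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def mlp_part(x, w):
    return torch.relu(x @ w)

# Compile each function separately, each with its own mode.
attention_compiled = torch.compile(attention_part, mode="max-autotune")
mlp_compiled = torch.compile(mlp_part, mode="reduce-overhead")

q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
x = torch.randn(1024, 256, device="cuda")
w = torch.randn(256, 256, device="cuda")

out_attn = attention_compiled(q, k, v)  # runs with max-autotune
out_mlp = mlp_compiled(x, w)            # runs with reduce-overhead
```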

Any ideas? I don't know how hard that would be, since I'm not a Python expert (mostly just tinkering with it), but it's pretty cool!

Overall, great work on it!

It should work seamlessly with torch.nn.functional.scaled_dot_product_attention (see the PyTorch 2.1 documentation); just make sure to enable the flash attention flag.
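For example, something along these lines (a minimal sketch against the PyTorch 2.1-era API, where the flash backend is selected with the torch.backends.cuda.sdp_kernel context manager; it assumes a CUDA GPU and fp16 inputs, which FlashAttention requires):

```python
import torch
import torch.nn.functional as F

@torch.compile(mode="max-autotune")
def attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# FlashAttention needs half-precision inputs on a CUDA GPU.
q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the FlashAttention backend for this call.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = attention(q, k, v)
```

In newer releases the backend selection has moved to torch.nn.attention.sdpa_kernel, if I remember correctly, but the idea is the same.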