Flex_attention and torch.compile

I just came across flex_attention, but I don’t really understand when and how I have to use torch.compile.

In the docs examples, I usually see something like

```python
flex_attention_compiled = torch.compile(flex_attention)
```

and then, later on, a call to flex_attention_compiled with some custom score_mod.

But how can this work? score_mod may not even be defined when torch.compile is called, and it is certainly not passed as the score_mod argument at that point. Does this mean that the code in score_mod is not compiled?

The same applies to the shapes of the query, key, and value arguments. The compiled graph certainly depends on these shapes, but at the time I call torch.compile, the arguments are not provided and may not even exist yet.

I have an application (long-context inference) where I need attention with different query shapes. I'd like to compile a graph for each distinct shape. How can I do that?

Sorry, I should have read the docs before asking this.
