I just came across flex_attention, but I don’t really understand when and how I am supposed to use torch.compile with it.
In the docs examples, I usually see something like

flex_attention_compiled = torch.compile(flex_attention)

and then, later on, a call to flex_attention_compiled with some custom score_mod.
But how can this work? score_mod may not even be defined at the point where torch.compile is called, and it is certainly not passed as the score_mod argument there. Does this mean that the code in score_mod is not compiled?
The same applies to the shapes of the query, key, and value arguments. Surely the compiled graph depends on these shapes, but at the time I call torch.compile, the arguments are not provided and may not even exist yet.
I have an application (long-context inference) where I need attention with different query shapes. I’d like to compile a graph for each distinct shape. How can I do that?