What's the best way to serialize and reuse torch.compile'd functions?


In our pipeline there are numerous lightweight but compute-intensive functions inside different nested loops that could be sped up with torch.compile.

These functions are separated by relatively complicated business logic that tends to change often, so compiling the entire nested loop is out of the question.

These functions are invoked thousands of times in the nested loops, and each call receives one of a handful of different input tensor shapes.

We know that compiling with dynamic=True avoids recompilation but adds runtime overhead. Without dynamic=True, every change in input shape triggers a recompilation, which adds even more runtime overhead.
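One way to observe the recompilation behavior described above is to count compilations with a toy custom backend. This is a minimal sketch: counting_backend, kernel and count_compiles are hypothetical names, and the backend simply returns gm.forward so the captured graph runs eagerly. That keeps the focus on how many times Dynamo compiles rather than on speed.

```python
import torch
import torch._dynamo

compile_count = 0


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Toy backend: count each graph Dynamo hands us, then run it eagerly."""
    global compile_count
    compile_count += 1
    return gm.forward


def kernel(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for one light-weight, compute-intensive function.
    return torch.sin(x) * 2 + 1


def count_compiles(dynamic: bool, sizes) -> int:
    """Call a freshly compiled `kernel` once per size and return the compile count."""
    global compile_count
    torch._dynamo.reset()  # drop cached graphs so each experiment starts clean
    compile_count = 0
    compiled = torch.compile(kernel, backend=counting_backend, dynamic=dynamic)
    for n in sizes:
        compiled(torch.randn(n))
    return compile_count


sizes = [8, 16, 32]
print("dynamic=False:", count_compiles(False, sizes))
print("dynamic=True: ", count_compiles(True, sizes))
```

With dynamic=False, each new size should produce its own compilation. With dynamic=True, a single graph with symbolic sizes should cover all three. The default, dynamic=None, sits in between: the first shape is compiled statically, and on the first shape change PyTorch recompiles a more dynamic version.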

Given that we only have about a dozen possible input shapes, and we're OK with a long initial compilation overhead, is there a way to pre-compile specialized versions of the same function, serialize them somewhere, and dispatch at runtime based on input shape?
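Short of serialization, one in-process approximation is to compile with dynamic=False and "warm up" the compiled function once per expected shape at startup. That way all compilation cost is paid before the hot loop. This is a sketch under assumptions: kernel, warm_up and EXPECTED_SHAPES are hypothetical, and the counting backend only makes compilations visible. Real use would keep the default backend. Two caveats apply. Dynamo has a per-function recompile limit (8 by default), after which it stops compiling new shapes for that function and runs it eagerly, so a dozen static shapes would require raising that limit (the config option's name differs across versions). Also, nothing here is serialized: the compiled code lives only in the current process.

```python
import torch

compile_count = 0


def counting_backend(gm, example_inputs):
    # Stand-in backend used only to count compilations; it runs graphs eagerly.
    # For real speedups, drop `backend=` to use the default (inductor).
    global compile_count
    compile_count += 1
    return gm.forward


def kernel(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for one of the pipeline's functions.
    return torch.sin(x) * 2 + 1


# Hypothetical list of the shapes the pipeline is known to use.
EXPECTED_SHAPES = [(8,), (16,), (4, 32)]

# dynamic=False: one specialized graph per distinct input shape.
compiled_kernel = torch.compile(kernel, backend=counting_backend, dynamic=False)


def warm_up(fn, shapes):
    """Trigger one specialized compilation per expected shape before the hot loop."""
    for shape in shapes:
        fn(torch.randn(*shape))


warm_up(compiled_kernel, EXPECTED_SHAPES)
compiles_after_warm_up = compile_count

# The hot loop only sees already-compiled shapes, so no new compilations occur.
for _ in range(100):
    for shape in EXPECTED_SHAPES:
        compiled_kernel(torch.randn(*shape))
```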

It looks like the output_code flag can be used to generate a Python-importable code snippet, but that feels quite hacky. Is there any formal serialization functionality on the roadmap?

Curious about a solution as well.

The formal solution for serializing compiled code is torch.export.