Torch.compile and Type Hinting

Hi there,

I’m wondering whether using type hints in the code helps or speeds up compilation, or whether it doesn’t matter at all.
I think it made a difference back in the old jit days. I wonder how that works now with torch.compile.

Any input is appreciated!

I think the answer here is no.

One of the core differences between torch.jit.script and PT2 is that:

  • torch.jit.script does static analysis of your Python code to try to optimize it. There, having the user give the static-analysis infra more info via type hints can potentially help with optimizing the code (see the first sketch after this list).

  • torch.compile doesn’t do static analysis - it traces your code with example inputs, and specializes the optimized code it generates on all of the concrete metadata it found when tracing the program with your particular inputs. One downside is that we might specialize and need to recompile later when you provide different inputs, but in general the idea/hope is that recompiles will happen rarely (as long as the user’s code is written in a reasonable way). The upside of tracing + specialization is that we don’t need type hints to get any extra data about the program, since we have the concrete values from your example inputs (see the second sketch below).
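To illustrate the first point, here’s a minimal sketch of where annotations actually mattered for torch.jit.script (the function and names are just for illustration): TorchScript assumes every unannotated argument is a Tensor, so the hints genuinely feed its static analysis.

```python
import torch
from typing import List

@torch.jit.script
def pad_all(xs: List[torch.Tensor], amount: int) -> List[torch.Tensor]:
    # TorchScript defaults every unannotated argument to Tensor, so without
    # the List[torch.Tensor]/int annotations this would fail to script.
    return [torch.nn.functional.pad(x, [0, amount]) for x in xs]
```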
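And a minimal sketch of the second point: torch.compile learns everything from the example inputs, and a guard failure (e.g., a new shape) triggers a recompile rather than any lookup of annotations. The logging call is optional and just makes the recompile visible.

```python
import torch

torch._logging.set_logs(recompiles=True)  # optional: print recompile reasons

@torch.compile
def f(x):
    return x.sin() + x.cos()

# First call: traced with this concrete input; the generated code is
# specialized on the observed dtype/device/shape -- no type hints involved.
f(torch.randn(8))

# A new shape fails the recorded shape guard, so this call recompiles
# (and by default marks that dimension as dynamic going forward).
f(torch.randn(16))
```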

There are a few cases, though, where the user can add extra info to their program to help torch.compile. One example is if you’re using torch.compile(m, mode="reduce-overhead") and performing training: PT2 will try to guess where the start of your training loop is so it knows where to start the cudagraph, but you can explicitly tell it with the torch._inductor.cudagraph_mark_step_begin() API (more info here: CUDAGraphs in Pytorch 2.0 - compiler - PyTorch Dev Discussions).
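For example, a rough sketch of a training loop with the mark (the model/optimizer here are placeholders, and this assumes a CUDA device, since reduce-overhead relies on cudagraphs):

```python
import torch

model = torch.nn.Linear(128, 128).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
compiled = torch.compile(model, mode="reduce-overhead")

for _ in range(10):
    # Tell PT2 exactly where each iteration begins, instead of letting
    # it guess where to place the cudagraph boundary.
    torch._inductor.cudagraph_mark_step_begin()
    loss = compiled(torch.randn(32, 128, device="cuda")).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```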
