Correct way to avoid torch.compile recompilations

What is the correct way to use torch.compile to avoid dynamic shape recompilations or over-specialization?

min_sum = torch.compile(lambda a, b: torch.min(a, b).sum(dim=-1))

I’m interested in passing in shapes:

  • (192, ) x (192, ) → scalar; (96, ) x (96, ) → scalar
  • (B, N, 1, 192) x (B, 1, N, 192) → (B, N, N); (B, N, 1, 96) x (B, 1, N, 96) → (B, N, N)

Can torch.compile generate kernels that avoid allocating and computing the intermediate (B, N, N, 192) tensor?
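
For concreteness, here is a minimal sketch of the two call patterns I mean (the B, N values and variable names are made up for illustration):

```python
import torch

min_sum = torch.compile(lambda a, b: torch.min(a, b).sum(dim=-1))

# 1D case: elementwise min of two (192,) vectors, summed over the last dim -> scalar
a = torch.randn(192)
b = torch.randn(192)
out = min_sum(a, b)  # shape ()

# Batched case: (B, N, 1, 192) broadcast against (B, 1, N, 192).
# In eager mode, torch.min(a, b) materializes a (B, N, N, 192) intermediate
# before .sum(dim=-1) reduces it to (B, N, N).
B, N = 4, 128
a = torch.randn(B, N, 1, 192)
b = torch.randn(B, 1, N, 192)
out = min_sum(a, b)  # shape (B, N, N)
```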

So basically I need two modes: a static shape mode and a dynamic shape mode (+broadcasting).

Can I express this, and how do I get the best speed out of it?
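
Is something like the following the intended way to express it? Rough sketch of what I would try (I'm not sure about the exact semantics of dynamic= vs. torch._dynamo.mark_dynamic here):

```python
import torch

fn = lambda a, b: torch.min(a, b).sum(dim=-1)

# Static-shape mode: specialize on the exact input shapes (one compile per shape)?
min_sum_static = torch.compile(fn, dynamic=False)

# Dynamic-shape mode: compile with symbolic shapes up front to avoid
# recompiling when B and N change?
min_sum_dynamic = torch.compile(fn, dynamic=True)

# Or keep the default behavior and mark only the dims that actually vary,
# so that the last dim (192 / 96) can stay specialized:
a = torch.randn(4, 128, 1, 192)
b = torch.randn(4, 1, 128, 192)
torch._dynamo.mark_dynamic(a, 0)  # B
torch._dynamo.mark_dynamic(a, 1)  # N
torch._dynamo.mark_dynamic(b, 0)  # B
torch._dynamo.mark_dynamic(b, 2)  # N
out = torch.compile(fn)(a, b)
```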

Ideally, I’m interested in both CPU and CUDA codegen.

Also, if I call torch.compile(lambda ...) like this in multiple places in the code, will the compilation cache be shared across all invocations?
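
For example, do these two patterns behave the same with respect to the cache (rough sketch)?

```python
import torch

# Pattern 1: re-wrap the lambda at every call site (a new compiled callable each time)
for _ in range(3):
    f = torch.compile(lambda a, b: torch.min(a, b).sum(dim=-1))
    f(torch.randn(192), torch.randn(192))

# Pattern 2: compile once at module level and reuse the same callable
min_sum = torch.compile(lambda a, b: torch.min(a, b).sum(dim=-1))
for _ in range(3):
    min_sum(torch.randn(192), torch.randn(192))
```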

Can I somehow force AOT pre-compilation for a given set of shapes and dtypes?
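
The only thing I can think of right now is an explicit in-process warm-up loop (sketch below, shapes made up), but that's not real ahead-of-time compilation:

```python
import torch

min_sum = torch.compile(lambda a, b: torch.min(a, b).sum(dim=-1))

# Warm-up: trigger compilation for the shape/dtype combinations I care about,
# so the first real call doesn't pay the compile cost.
for d in (96, 192):
    for dtype in (torch.float32, torch.float64):
        min_sum(torch.zeros(d, dtype=dtype), torch.zeros(d, dtype=dtype))
        min_sum(
            torch.zeros(2, 8, 1, d, dtype=dtype),
            torch.zeros(2, 1, 8, d, dtype=dtype),
        )
```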

Thank you :slight_smile:

I opened a GitHub issue with my naive stab at this task: min_sum 3x slower on torch.compile (torch 2.0.1, CPU) · Issue #106466 · pytorch/pytorch · GitHub