What is the correct way to use torch.compile
to avoid dynamic shape recompilations or over-specialization?
min_sum = torch.compile(lambda a, b: torch.min(a, b).sum(dim=-1))
I’m interested in passing in shapes:
- (192,) x (192,) → scalar; (96,) x (96,) → scalar
- (B, N, 1, 192) x (B, 1, N, 192) → (B, N, N); (B, N, 1, 96) x (B, 1, N, 96) → (B, N, N)
Can torch.compile generate fused kernels that avoid allocating and computing the intermediate (B, N, N, 192) tensor?
So essentially I need two modes: a static-shape mode and a dynamic-shape mode (plus broadcasting).
Can I express this, and how do I get the best speed in each mode?
Ideally, I’m interested in both CPU and CUDA codegen.
Also, if I call torch.compile(lambda ...) like this at multiple places
in the code, will all call sites share the compilation cache?
Can I somehow force AOT pre-compilation for a given set of shapes and dtypes?
Thank you