Larger inference batch sizes with torch.compile yield untenable warm-up times

Observations:

For a large-batch inference workload, torch.compile takes a very long time to warm up, e.g. the warm-up quickly goes from minutes to hours even with modestly sized batches.

However, the un-compiled model can easily ingest much larger batches.

Modeling context:

  • torch.inference_mode()
  • fp16 mixed-precision (autocast) context
  • LLM with scaled dot product attention module
  • large batch inference job using multiple GPU workers
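
For concreteness, here is a minimal, self-contained sketch of the kind of setup I mean. The toy SDPA block, `run_batch`, and the shapes/sizes are just illustrative stand-ins for the actual LLM and job, not the real code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySDPABlock(nn.Module):
    """Toy block using scaled_dot_product_attention, standing in for the real LLM."""
    def __init__(self, dim=512, heads=8):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        b, t, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (batch, heads, seq, head_dim) for SDPA
        q, k, v = (y.view(b, t, self.heads, d // self.heads).transpose(1, 2) for y in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(attn.transpose(1, 2).reshape(b, t, d))

model = TinySDPABlock().eval().cuda()
compiled_model = torch.compile(model)

@torch.inference_mode()
def run_batch(x):
    # fp16 mixed-precision context, as in the actual job
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        return compiled_model(x)
```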

So is it correct that:

  • there are significant performance gains to be had with the compiled model,
  • but there is also significant warm-up time for larger batch sizes,
  • the warm-up artifacts cannot easily be cached and shared as of today,
  • and thus there is a throughput trade-off between compiling and not compiling once the warm-up time is accounted for (rough timing sketch below).
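
To make the trade-off concrete, this is roughly how I'm timing it, using the toy `run_batch` from the sketch above (the real numbers come from the actual multi-GPU LLM job; this only shows the methodology):

```python
import time
import torch

def timed(fn, *args):
    """Wall-clock a single call, synchronizing around it so GPU work is included."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*args)
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

for batch_size in (8, 32, 128):
    x = torch.randn(batch_size, 1024, 512, device="cuda")
    # first call at this batch size: may pay the (re)compilation / warm-up cost
    _, first_s = timed(run_batch, x)
    # second call at the same shape: steady-state latency
    _, steady_s = timed(run_batch, x)
    print(f"batch={batch_size}: first call {first_s:.1f}s, steady-state {steady_s * 1e3:.1f}ms")
```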

Question 1: How does the warm-up time scale with batch size?

Question 2: Is it correct to think that smaller batch sizes should be used with torch.compile to balance the warm-up time penalty against inference throughput?

Thanks,
Fred