Observations:
For a large-batch inference workload, torch.compile takes a very long time to warm up; the warmup quickly goes from minutes to hours with modestly sized batches.
However, the un-compiled model can easily ingest much larger batches.
Modeling context (a minimal code sketch follows this list):
- torch.inference_mode()
- fp16 mixed-precision (autocast) context
- LLM with a scaled dot product attention module
- large batch inference job using multiple GPU workers
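For concreteness, here is a minimal sketch of roughly what the setup looks like. The model class, dimensions, and vocabulary size below are placeholders I made up for illustration; the real workload is a full LLM sharded across multiple GPU workers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAttnModel(nn.Module):
    """Placeholder stand-in for an LLM block built around scaled_dot_product_attention."""
    def __init__(self, d_model=256, n_heads=4, vocab=32000):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, vocab)
        self.n_heads = n_heads

    def forward(self, ids):
        x = self.embed(ids)                          # (B, T, D)
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (B, heads, T, head_dim) for SDPA
        shape = (B, T, self.n_heads, D // self.n_heads)
        q, k, v = (t.view(shape).transpose(1, 2) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).reshape(B, T, D)
        return self.out(y)

model = TinyAttnModel().eval().cuda()
compiled = torch.compile(model)  # compilation is deferred until the first forward pass

batch_size, seq_len = 32, 512
ids = torch.randint(0, 32000, (batch_size, seq_len), device="cuda")

with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = compiled(ids)  # the first call triggers the (slow) compile/warmup
```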
So is it correct that:
- there are significant performance gains to be had with the compiled version,
- but there is also a significant warmup time for larger batch sizes,
- the warmup artifacts cannot easily be cached and shared as of today,
- thus there is a trade-off in throughput between compiling and not compiling once the warmup time is accounted for (a rough way to measure this break-even point is sketched below)?
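To make the last point concrete, this is how I am currently thinking about quantifying the break-even point, building on the sketch above. The timing helper and the break-even arithmetic are just my own back-of-the-envelope approach, not an established recipe, so corrections are welcome.

```python
import time
import torch

def timed(fn):
    """Return (result, elapsed seconds), synchronizing so GPU work is fully counted."""
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn()
    torch.cuda.synchronize()
    return out, time.perf_counter() - t0

with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
    _, warmup_s = timed(lambda: compiled(ids))   # first call: includes compilation
    _, steady_s = timed(lambda: compiled(ids))   # same shape again: steady-state compiled
    _, eager_s = timed(lambda: model(ids))       # un-compiled baseline

extra_warmup = warmup_s - steady_s               # one-time cost paid up front
per_batch_saving = eager_s - steady_s            # gain per batch once compiled
break_even = extra_warmup / max(per_batch_saving, 1e-9)
print(f"warmup {warmup_s:.1f}s, compiled {steady_s*1e3:.1f}ms/batch, "
      f"eager {eager_s*1e3:.1f}ms/batch, break-even after ~{break_even:.0f} batches")
```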
Question 1: How does the warmup time scale with batch size?
Question 2: Is it correct to think that smaller batch sizes should be used with torch.compile to balance the warmup-time penalty against inference throughput?
Thanks,
Fred