How to Persist JIT Fused CUDA Kernels State for Efficient Inference on Multi-GPU Setup?

Hi, PyTorch Community!

I’m currently working on a deep learning project focused on computer vision, utilizing CNNs for inference. Our setup involves a multi-GPU environment, specifically using 8 GPUs, to handle our workload. While we’ve been able to achieve the performance we aimed for, there’s a notable challenge we’re facing: the initial “warm-up” time.

Problem Description: Our program takes approximately 50 seconds to “warm-up” every time we start it. This warm-up phase significantly impacts our overall efficiency, especially in scenarios where quick restarts are necessary. Upon investigation, we’ve identified that a substantial portion of this time is likely due to the JIT (Just-In-Time) compilation process, particularly the fusion and optimization of CUDA kernels.

Environment Details:

  • PyTorch version: 2.2.2
  • CUDA version: 12.1
  • Number of GPUs: 8
  • GPU model: NVIDIA A10G
  • Nature of the workload: Inference using deep CNNs for computer vision tasks

Main Question: Is there a way to serialize and persist the state of the JIT fused CUDA kernels to disk after the first compilation? Our goal is to do this “warm-up” process once and then reuse the optimized state in subsequent runs to eliminate or significantly reduce the warm-up time.


  1. If this serialization and persistence is possible, what are the steps to achieve it?
  2. How can we then reload this saved state from the disk before running our inference workload to bypass the JIT compilation process?

Why This Matters: Reducing or eliminating the warm-up time in our environment could improve our system’s responsiveness and overall throughput. Given the scale at which we’re operating, even small efficiency improvements can lead to significant benefits.

I’d appreciate any guidance or suggestions. If anyone has tackled similar challenges or knows of potential solutions, your input would be appreciated.

Thank you in advance!

Are you using the deprecated torch.jit.script or the newer torch.compile approach?
The latter should use a cache, if I’m not mistaken, but adding @marksaroufim just in case I’m missing something.

In the torch.compile() world you can make warm startup times much faster by setting torch._inductor.config.fx_graph_cache = True. That said, warm startup times in torch.compile will be significantly improved soon; see the “PT2 Core - H1 2024 Public Roadmap” post on PyTorch Dev Discussions.
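A minimal sketch of how the cache flag above might be wired up so that compiled artifacts survive restarts. The helper names (`enable_inductor_cache`, `compile_for_inference`) and the cache directory are hypothetical; `TORCHINDUCTOR_CACHE_DIR` and `TORCHINDUCTOR_FX_GRAPH_CACHE` are the environment variables Inductor is assumed to read here, so please check them against your PyTorch version.

```python
import os


def enable_inductor_cache(cache_dir):
    """Point Inductor at a persistent cache directory and request the FX graph cache.

    Assumption: it is safest to call this before torch is imported, because
    the config flag is read from the environment when the config module loads.
    """
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
    return cache_dir


def compile_for_inference(model):
    """Wrap an nn.Module with torch.compile, with the FX graph cache turned on."""
    import torch  # imported lazily so the environment variables are already set
    import torch._inductor.config as inductor_config

    # The flag named in the reply above; set explicitly in case torch was
    # already imported before the environment variables were written.
    inductor_config.fx_graph_cache = True
    return torch.compile(model.eval())
```

The first run still pays the full compilation cost; later runs that see the same graphs and the same cache directory should be able to reuse the cached results. On a multi-GPU host, each process would need access to that directory.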

And if waiting is not an option, you can instead try out AOTInductor; see “AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models” in the PyTorch main documentation.

Awesome, @ptrblck and @marksaroufim!

So, we are currently using torch.jit.script to generate TorchScript and running inference via LibTorch/C++, but thanks to you I’ve just realized that we should switch over to torch.compile() to leverage the AOTInductor introduced in PyTorch 2.2. It seems to be precisely what we need. We’ll give it a try.

So, I’ve just encountered the same issue as described in “Missing Symbols When running AOTInductor example with Libtorch c++11 ABI”. I suspect -D_GLIBCXX_USE_CXX11_ABI=1 needs to be passed to the AOT compiler. Is there any way to supply additional compiler options manually?
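For anyone diagnosing the same mismatch, here is a small sketch for checking which libstdc++ ABI the Python-side PyTorch build uses, so it can be compared with the LibTorch distribution linked from C++. The helper names are hypothetical; `torch.compiled_with_cxx11_abi()` is the only PyTorch call used.

```python
def cxx11_abi_define(use_cxx11_abi: bool) -> str:
    """Return the preprocessor define that corresponds to an ABI choice."""
    return f"-D_GLIBCXX_USE_CXX11_ABI={int(use_cxx11_abi)}"


def torch_abi_define() -> str:
    """Report the define that matches the installed PyTorch build."""
    import torch  # imported lazily so the helper above works without torch

    return cxx11_abi_define(torch.compiled_with_cxx11_abi())
```

If `torch_abi_define()` reports `=0` while the C++ side links against a cxx11-ABI LibTorch, the missing-symbol errors described above are the likely result.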

Setting aot_inductor.abi_compatible to True resolves the problem in our environment. For example:

    # model, example_inputs, and batch_dim are assumed to be defined as in the AOTInductor example
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
        # Specify the generated shared library path
        options={"aot_inductor.output_path": os.path.join(os.getcwd(), ""), "aot_inductor.abi_compatible": True},
    )

I have also added a comment to the GitHub issue.