How to use `torch.compile` with CUDA graphs when using gradient activation checkpointing

Hello! I would like to ask about using the latest torch.compile together with the somewhat older torch.cuda.make_graphed_callables function to create CUDA graphs when gradient activation checkpointing and multi-GPU training are enabled.
It is currently my understanding that torch.compile reduces overhead by fusing operations into larger kernels, while CUDA graphs reduce CPU launch overhead by recording a sequence of kernel launches once and replaying it. However, I have not seen any explanation or tutorial on whether they can be used together. Is it possible to capture the CUDA graph of a model that has undergone torch.compile? I do not think there is a fundamental reason why this could not be done, though it may be technically time-consuming.
Also, activation checkpointing is a common optimization in Transformer architectures, which are also frequently trained on multiple GPUs. Is it possible to combine torch.compile and CUDA graph creation with multi-GPU training, perhaps only with DDP but not with FSDP?
Though these conditions may appear complicated, they are quite common in training large language models (LLMs), which have gained enormous popularity. I believe that making these optimizations work together would be very beneficial, and I would be grateful for any answers or informative discussion in advance.

If you wish to use torch.compile with CUDA graphs, the preferred method to do so would probably be via the option mode="reduce-overhead" which should use CUDA graphs according to the argument documentation:

       mode (str): Can be either "default", "reduce-overhead" or "max-autotune"
        - "default" is the default mode, which is a good balance between performance and overhead
        - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs, useful for small batches
        - "max-autotune" is a mode that leverages Triton based matrix multiplications and convolutions
        - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`

I believe torch.compile should be compatible with DDP, and compatible with FSDP as well, since some issues were recently addressed, e.g., [FSDP] `use_orig_params=True` with CPU offloading and Gradient Accumulation: RuntimeError · Issue #98494 · pytorch/pytorch · GitHub.
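For the DDP case, a rough sketch of the usual ordering (wrap in DDP first, then compile the wrapped module; model, shapes, and the NCCL backend choice are all illustrative, and this would be launched with torchrun):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets LOCAL_RANK and the rendezvous environment variables.
    dist.init_process_group("nccl")
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(64, 64).cuda(rank)
    ddp_model = DDP(model, device_ids=[rank])
    # Compile the DDP-wrapped module so Dynamo sees the communication hooks.
    compiled = torch.compile(ddp_model)

    x = torch.randn(8, 64, device=rank)
    loss = compiled(x).sum()
    loss.backward()  # gradients are all-reduced across ranks by DDP

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```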

I tried adding torch.compile() to the sample in this blog: Accelerating PyTorch with CUDA Graphs | PyTorch
The output of nsys profile shows that the kernels executed before and after the model is compiled differ. After compilation, some kernels such as triton__0d1d2d3 are executed, while no such kernel appears without torch.compile().
I doubt there is any benefit to using torch.compile() together with a CUDA graph, as in the mentioned case the walltime of graph replay is almost the same with and without torch.compile().
By the way, when using the reduce-overhead mode, the compiled model cannot be called during graph capture, or this error is raised:

RuntimeError: Cannot call CUDAGeneratorImpl::current_seed during CUDA graph capture. If you need this call to be captured, please file an issue. Current cudaStreamCaptureStatus: cudaStreamCaptureStatusActive

However, the default mode does not encounter this error.
I hope there will be an official tutorial about using torch.compile() with CUDA graphs.

I don’t believe manually capturing a compiled graph would be a supported use case, and I would recommend sticking with @eqy’s suggestion of using the reduce-overhead mode instead.