How to use `torch.compile` with CUDA graphs when using gradient activation checkpointing

Hello! I would like to ask about using the latest torch.compile together with the somewhat older torch.cuda.make_graphed_callables function to create CUDA graphs when gradient activation checkpointing and multi-GPU training are enabled.
It is currently my understanding that torch.compile reduces GPU-side overhead (for example by fusing kernels), while CUDA graphs reduce CPU overhead by launching a whole sequence of kernels as a single graph. However, I have not seen any explanation or tutorial on whether they can be used together. Is it possible to capture the CUDA graph of a model that has gone through torch.compile? I do not think there is a fundamental reason why this could not be done, though it may be technically time-consuming.
Also, activation checkpointing is a common optimization in Transformer architectures, which are also frequently trained on multiple GPUs. Is it possible to combine torch.compile and CUDA graph creation with multi-GPU training, possibly only with DDP but not with FSDP?
Though these conditions may appear complicated, they are quite common when training large language models (LLMs), which have gained enormous popularity. I believe these optimizations would be very beneficial, and I would like to thank everyone in advance for any answers or informative discussion.
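For reference, here is a minimal sketch of how torch.cuda.make_graphed_callables is typically applied to a plain (uncompiled) module. The toy MLP, the `graph_mlp` helper name, and the batch/feature sizes are hypothetical stand-ins, not part of the original question.

```python
import torch


def graph_mlp(batch_size: int = 16, dim: int = 64) -> torch.nn.Module:
    # Hypothetical toy MLP standing in for a real model; sizes are assumptions.
    model = torch.nn.Sequential(
        torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(), torch.nn.Linear(4 * dim, dim)
    ).cuda()
    # sample_args must match the shape, dtype and requires_grad of the real
    # inputs, because the captured graphs replay on fixed-size static buffers.
    sample = (torch.randn(batch_size, dim, device="cuda"),)
    # Warms up, then captures forward and backward as CUDA graphs; for a single
    # module, the same module is returned with its forward replaced.
    return torch.cuda.make_graphed_callables(model, sample)


if torch.cuda.is_available():
    graphed = graph_mlp()
    opt = torch.optim.SGD(graphed.parameters(), lr=1e-2)
    x = torch.randn(16, 64, device="cuda")
    loss = graphed(x).square().mean()
    loss.backward()  # replays the captured backward graph
    opt.step()
```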


If you wish to use torch.compile with CUDA graphs, the preferred method would probably be the option mode="reduce-overhead", which should use CUDA graphs according to the argument documentation:

        mode (str): Can be either "default", "reduce-overhead" or "max-autotune"
            - "default" is the default mode, which is a good balance between performance and overhead
            - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs, useful for small batches
            - "max-autotune" is a mode that leverages Triton based matrix multiplications and convolutions
            - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`

I believe torch.compile should be compatible with DDP, and also with FSDP now that some issues have recently been addressed, e.g., [FSDP] `use_orig_params=True` with CPU offloading and Gradient Accumulation: RuntimeError · Issue #98494 · pytorch/pytorch · GitHub.
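A minimal sketch of the DDP case, assuming a launch via torchrun (which sets LOCAL_RANK) on NCCL-capable GPUs. The `build_compiled_ddp` helper and the toy Linear layer are hypothetical. It uses the default compile mode; whether mode="reduce-overhead" combines well with DDP in a given PyTorch version is something to verify separately.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def build_compiled_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # Hypothetical helper: move to this rank's GPU, wrap in DDP, then compile.
    model = model.to(f"cuda:{local_rank}")
    # Compiling the DDP wrapper lets Dynamo's DDP support split the graph at
    # gradient-bucket boundaries so all-reduce can overlap with backward.
    ddp_model = DDP(model, device_ids=[local_rank])
    return torch.compile(ddp_model)


if __name__ == "__main__" and "LOCAL_RANK" in os.environ and torch.cuda.is_available():
    # Assumed launch: torchrun --nproc_per_node=2 this_script.py
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    net = build_compiled_ddp(torch.nn.Linear(32, 32), local_rank)
    out = net(torch.randn(8, 32, device=f"cuda:{local_rank}"))
    out.sum().backward()
    dist.destroy_process_group()
```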

I tried adding torch.compile() to the sample in this blog post: Accelerating PyTorch with CUDA Graphs | PyTorch
The output of nsys profile shows that the kernels executed with and without compilation differ. After the model is compiled, some kernels such as triton__0d1d2d3 are executed, while there is no such kernel without torch.compile().
I suspect there is no benefit to using torch.compile() together with CUDA graphs, since in this case the wall time of graph replay is almost the same with and without torch.compile().
By the way, when using the reduce-overhead mode, the compiled model cannot be used during manual graph capture; otherwise the following error is raised:

RuntimeError: Cannot call CUDAGeneratorImpl::current_seed during CUDA graph capture. If you need this call to be captured, please file an issue. Current cudaStreamCaptureStatus: cudaStreamCaptureStatusActive

However, the default mode does not encounter this error.
I hope there will be an official tutorial on using torch.compile() with CUDA graphs.
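For context, the manual whole-network capture pattern from that blog post looks roughly like the sketch below, shown for inference on a plain model without torch.compile. The `capture_inference` helper name and the toy Linear layer are hypothetical.

```python
import torch


def capture_inference(model: torch.nn.Module, example: torch.Tensor):
    """Capture model(example) into a CUDA graph and return a replay function."""
    static_input = example.clone()
    with torch.no_grad():
        # Warm up on a side stream before capture, as the CUDA graphs docs recommend.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(3):
                model(static_input)
        torch.cuda.current_stream().wait_stream(side)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_output = model(static_input)

    def replay(new_input: torch.Tensor) -> torch.Tensor:
        static_input.copy_(new_input)  # the graph always reads this fixed buffer
        graph.replay()
        return static_output.clone()  # copy out before the next replay overwrites it

    return replay


if torch.cuda.is_available():
    net = torch.nn.Linear(64, 64).cuda().eval()
    run = capture_inference(net, torch.randn(8, 64, device="cuda"))
    x = torch.randn(8, 64, device="cuda")
    with torch.no_grad():
        assert torch.allclose(run(x), net(x), atol=1e-5)
```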

I don’t believe manually capturing a compiled graph would be a supported use case, and I would recommend sticking to @eqy's suggestion of using the reduce-overhead mode instead.
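Applied to the original question, a minimal sketch of a training step that combines activation checkpointing with mode="reduce-overhead" instead of manual capture could look like this. The block architecture, sizes, and helper names are hypothetical, and it assumes a recent PyTorch release in which Dynamo can trace non-reentrant checkpointing (use_reentrant=False).

```python
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(torch.nn.Module):
    """Hypothetical residual MLP block whose inner activations are recomputed in backward."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(), torch.nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Non-reentrant checkpointing is assumed here; it is the variant PyTorch recommends.
        return x + checkpoint(self.ff, x, use_reentrant=False)


def make_train_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    # torch.compile decides where CUDA graphs are safe; no manual capture here.
    compiled = torch.compile(model, mode="reduce-overhead")

    def step(x: torch.Tensor, target: torch.Tensor) -> float:
        optimizer.zero_grad(set_to_none=True)
        loss = torch.nn.functional.mse_loss(compiled(x), target)
        loss.backward()
        optimizer.step()
        return loss.item()

    return step


if torch.cuda.is_available():
    model = torch.nn.Sequential(CheckpointedBlock(64), CheckpointedBlock(64)).cuda()
    step = make_train_step(model, torch.optim.AdamW(model.parameters(), lr=1e-3))
    x, y = torch.randn(16, 64, device="cuda"), torch.randn(16, 64, device="cuda")
    for _ in range(5):
        loss = step(x, y)
```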


@ptrblck I do not fully understand how torch.compile and CUDA Graphs work together, and I have not found any detailed documentation on this.

But my understanding is the following: given a plain PyTorch model (one that does not use CUDA Graphs on its own), we can compile it with torch.compile, and if we pass mode="reduce-overhead", the compiler should be smart enough to determine which regions are “CUDA Graph capture safe” and graph those regions automatically. I say “automatically” based on the example code in the docs on CUDAGraph Trees integration:

import torch

@torch.compile(mode="reduce-overhead")
def foo(x):
    # GRAPH 1
    y = x * x * x
    # graph break triggered here
    if y.sum() > 0:
        # GRAPH 2
        z = y ** y
    else:
        # GRAPH 3
        z = (y.abs() ** y.abs())
    torch._dynamo.graph_break()
    # GRAPH 4
    return z * torch.rand_like(z)  # rand_like needs a floating-point input

# the first run warms up each graph, which does things like cuBLAS or Triton benchmarking
foo(torch.arange(0, 10, device="cuda", dtype=torch.float32))
# The second run does a CUDA Graph recording, and replays it
foo(torch.arange(0, 10, device="cuda", dtype=torch.float32))
# Finally we hit the optimized, CUDA Graph replay path
foo(torch.arange(0, 10, device="cuda", dtype=torch.float32))

Here there is no explicit use of CUDA Graphs, but the comments suggest that CUDA Graphs are being used automatically in the background.

Is my understanding correct, or is it flawed?

From the PyTorch docs:

mode (str) –

Can be either “default”, “reduce-overhead”, “max-autotune” or “max-autotune-no-cudagraphs”

  • “reduce-overhead” is a mode that reduces the overhead of python with CUDA graphs, useful for small batches. Reduction of overhead can come at the cost of more memory usage, as we will cache the workspace memory required for the invocation so that we do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed to work; today, we only reduce overhead for CUDA-only graphs which do not mutate inputs. There are other circumstances where CUDA graphs are not applicable; use TORCH_LOGS=perf_hints to debug.
  • “max-autotune” is a mode that leverages Triton based matrix multiplications and convolutions. It enables CUDA graphs by default.

I guess this automatic use of CUDA Graphs by the compiler is what @ezyang was hinting at in this comment.

Also, regarding @ezyang's comment here:

If you’re willing to do a detour in making your code torch.compile’able, it might tell you about what might be the problem.

In the context of Dynamo and CUDA Graphs, I guess the tool he means, the one that tells us “what might be the problem”, is TORCH_LOGS=perf_hints, as mentioned in the “reduce-overhead” documentation above.
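A minimal sketch of enabling those hints from Python: the environment variable has to be set before torch is imported. The `double_inplace` function is hypothetical; it deliberately mutates its input, one of the cases the docs say CUDA graphs skip, so a "skipping cudagraphs"-style hint is expected in the log on a GPU, though the exact message may vary across versions.

```python
import os

# TORCH_LOGS is read when torch is imported, so set it before the import.
os.environ["TORCH_LOGS"] = "perf_hints"

import torch


@torch.compile(mode="reduce-overhead")
def double_inplace(x: torch.Tensor) -> torch.Tensor:
    # Mutating an input is one of the cases where CUDA graphs are not applied.
    x.mul_(2)
    return x


if torch.cuda.is_available():
    t = torch.ones(8, device="cuda")
    for _ in range(3):
        double_inplace(t)  # watch stderr for the perf hint explaining the skip
```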

Actually, I find the concept of the PyTorch compiler very interesting, and the fact that it can handle CUDA Graphs (automatically?) makes me want to dig deeper into its details.

So, as a follow-up question, is there any detailed documentation or tutorial on this?

Yes, your explanation is correct and torch.compile with reduce-overhead should apply CUDA Graphs automatically where applicable.
However, there might still be a few rough edges, as described in this and this issue (convergence problems) and in this and this issue (memory leaks).