PyTorch CUDA graph with NCCL operation failed

Here is a simple example:

import os

import torch
import torch.distributed as dist

def main():
    # Initialize the NCCL backend
    dist.init_process_group(backend='nccl')
    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ['RANK'])

    # Create a tensor on the GPU
    tensor = torch.rand(10).cuda(rank)

    # Start CUDA graph capture
    stream = torch.cuda.Stream(device=f"cuda:{rank}")
    graph = torch.cuda.CUDAGraph()
    stream.synchronize()
    with torch.cuda.graph(graph, stream=stream):
        # Perform all-reduce operation
        dist.all_reduce(tensor)
    stream.synchronize()

    # Execute the graph
    graph.replay()

    print("All-reduce completed:", tensor)

if __name__ == "__main__":
    main()

Running it with torchrun --nproc-per-node 8 test.py, I got the following error:

torch/distributed/distributed_c10d.py", line 2050, in all_reduce
    work = group.allreduce([tensor], opts)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/NCCLUtils.hpp:219, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.18.1
ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Cuda failure 'operation not permitted when stream is capturing'

Per the NCCL documentation:

Starting with NCCL 2.9, NCCL operations can be captured by CUDA Graphs.

So why does the PyTorch CUDA graph fail to capture the all-reduce operation?

From the docs:

Disable DDP’s internal async error handling
Before full-backward capture, DDP must be constructed in a side-stream context
Your warmup must run at least 11 DDP-enabled eager iterations before capture.
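
For concreteness, here is a minimal sketch of that recipe (my own illustration, not taken from the docs): the toy model, the tensor sizes, and the exact environment variable for disabling async error handling are assumptions and may differ across PyTorch versions.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# NOTE: the variable name has changed across releases
# (TORCH_NCCL_ASYNC_ERROR_HANDLING in newer ones); set it before the
# process group is created.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"

dist.init_process_group(backend="nccl")
rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)

model = nn.Linear(64, 64).cuda(rank)  # placeholder model

# DDP must be constructed in a side-stream context before capture
side = torch.cuda.Stream()
with torch.cuda.stream(side):
    ddp_model = DDP(model, device_ids=[rank])
torch.cuda.current_stream().wait_stream(side)

static_input = torch.randn(32, 64, device=f"cuda:{rank}")

# At least 11 DDP-enabled eager iterations before capture
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(11):
        ddp_model.zero_grad(set_to_none=True)
        ddp_model(static_input).sum().backward()
torch.cuda.current_stream().wait_stream(side)

# Capture one full forward + backward; the bucketed gradient allreduce
# is recorded into the graph
graph = torch.cuda.CUDAGraph()
ddp_model.zero_grad(set_to_none=True)
with torch.cuda.graph(graph):
    static_loss = ddp_model(static_input).sum()
    static_loss.backward()

graph.replay()  # re-runs forward, backward, and the gradient allreduce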

If I have a large network, do I need to warm up the whole network for at least 11 runs, or can I just launch some small DDP/PyTorch iterations to fulfill the requirement?

You need full iterations of the DDP workload you want to capture and replay.

What if I’m just using inference without grad, i.e. just using torch.distributed.all_reduce? Do I still have to run the full workload for more than 11 runs?

I would claim it wouldn’t hurt, but if you don’t need to reduce any gradients or initialize the grad buckets, you might be able to reduce the warmup iterations to 3.
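
For the inference-only case, a minimal sketch of that warmup-then-capture pattern might look like the following (the tensor shape and the 3 warmup iterations are just illustrative, and exact behavior can depend on the PyTorch/NCCL version):

import os

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend='nccl')
    rank = int(os.environ['RANK'])
    torch.cuda.set_device(rank)

    static_tensor = torch.rand(10, device=f"cuda:{rank}")

    # Eager warmup on a side stream: the first collective also initializes
    # the NCCL communicator, which cannot run while a stream is capturing
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            dist.all_reduce(static_tensor)
    torch.cuda.current_stream().wait_stream(s)
    torch.cuda.synchronize()

    # Capture a single all_reduce on the static tensor, then replay it
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        dist.all_reduce(static_tensor)

    graph.replay()
    torch.cuda.synchronize()
    print("All-reduce completed:", static_tensor)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

The key difference from the snippet at the top of the thread is that the communicator is created and warmed up eagerly before capture starts, so only capture-safe collective launches are recorded into the graph.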

These numbers are sooooo magic :rofl:

There are of course technical reasons: making sure the heuristics were able to select the fastest kernels, that kernel fusions and other optimizations are done, that buffers were allocated… But the true reason is of course:
[meme image]


lol, is this an existing meme or did you just create one for this thread?

Yes, I made it, but I already had it ready from another occasion :wink: