Does torch support a custom stream for NCCL communication now?

There are many discussions and posts saying that torch can't use a custom stream for NCCL communication, since ProcessGroupNCCL maintains its own internal stream pool for NCCL:

  // The CUDA streams used by NCCL kernels
  std::unordered_map<std::string, at::cuda::CUDAStream> ncclStreams_;

When we write code like the snippet below, comm_stream is not actually used for the communication; it only acts as a dummy compute stream, since NCCL makes the communication wait on the current stream to ensure the input data is ready.

a = some_compute()
with torch.cuda.stream(comm_stream):
    dist.all_gather(output, a)
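
One way to check this empirically (a sketch of my own, not from the posts above; it assumes the NCCL process group is already initialized) is to profile a single collective and look at which CUDA stream the NCCL kernel lands on:

import torch
import torch.distributed as dist
from torch.profiler import profile, ProfilerActivity

# Assumes dist.init_process_group("nccl") was already called and the
# CUDA device for this rank has been selected.
comm_stream = torch.cuda.Stream()
a = torch.randn(1024, device="cuda")
output = [torch.empty_like(a) for _ in range(dist.get_world_size())]

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    with torch.cuda.stream(comm_stream):
        dist.all_gather(output, a)
    torch.cuda.synchronize()

# Every kernel in the exported trace carries a stream id; comparing the
# NCCL kernel's stream with that of other work launched on comm_stream
# shows whether the collective really ran there or on an internal stream.
prof.export_chrome_trace("allgather_trace.json")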

But when I dove into ProcessGroupNCCL, I found that torch chooses the NCCL stream based on whether the op is async. If it is, torch uses an internal stream from ncclStreams_, which is a map from device to CUDA stream; otherwise, it simply uses the current CUDA stream:

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
    std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    Fn fn,
    PreProcess pre,
    PostProcess post,
    OpType opType,
    bool asyncOp,
    const char* profilingTitle,
    bool nanCheck) {

  // ....

  // in asyncOp=false [default] mode, we use currentStream as ncclStream
  // otherwise, we use separate ncclStream and let it sync on currentStream
  auto ncclStream = asyncOp ? ncclStreams_.at(key)
                            : at::cuda::getCurrentCUDAStream(device.index());

The two statements above seem somewhat contradictory, and I’m very curious which one is correct. Can torch actually customize the stream that NCCL will use?

The ProcessGroupNCCL code is the correct reference here. When you do:

with torch.cuda.stream(comm_stream):
    dist.all_gather(output, a)

the all_gather will happen on comm_stream, because async_op defaults to False and the collective then uses the current stream. If the op is async (async_op=True), it will use the internal stream instead.
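
To make that concrete, here is a minimal sketch of the two paths, assuming a recent PyTorch where asyncOp=False means "use the current stream", as in the collective() code quoted above. The tensor sizes, stream-sync calls, and setup are illustrative and assume the NCCL process group is already initialized.

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") was already called and
# torch.cuda.set_device(...) points at this rank's GPU.
comm_stream = torch.cuda.Stream()
a = torch.randn(1024, device="cuda")  # produced on the default stream
output = [torch.empty_like(a) for _ in range(dist.get_world_size())]

# Path 1: async_op=False (the default) -> NCCL launches on the *current*
# stream, which is comm_stream inside the context manager below.
comm_stream.wait_stream(torch.cuda.current_stream())  # input `a` must be ready
with torch.cuda.stream(comm_stream):
    dist.all_gather(output, a)
    a.record_stream(comm_stream)  # keep the allocator from reusing `a` too early
torch.cuda.current_stream().wait_stream(comm_stream)  # before consuming `output`

# Path 2: async_op=True -> NCCL launches on ProcessGroupNCCL's internal
# stream, regardless of which stream is current; sync via the Work handle.
work = dist.all_gather(output, a, async_op=True)
work.wait()  # makes the current stream wait for the collective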

Thanks, so torch already supports custom NCCL streams to some extent.

I find that only the latest torch versions can use a user-specified NCCL stream.

In torch 2.7, ProcessGroupNCCL just uses its internal stream:

// Used many times below, so we stash the unordered_map lookup
auto ncclStream = ncclStreams_.at(key);
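
So code that wants to rely on the current-stream behavior probably has to gate on the torch version. A rough sketch; the exact cutoff is an assumption based on this thread rather than something verified against release notes:

import torch

# Assumption from this thread: 2.7 always uses the internal NCCL stream,
# while later releases honor the current stream when async_op=False.
_major, _minor = (int(v) for v in torch.__version__.split(".")[:2])
CAN_USE_CUSTOM_NCCL_STREAM = (_major, _minor) > (2, 7)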