Edenzzzz
(Edenzzzz)
August 7, 2024, 9:16am
1
When using point-to-point communications, it might sometimes be desirable to schedule unrelated comm calls on different streams, e.g. in the backward of Ring attention.
Based on this example (InternEvo/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py at f2949df89c15e1c16b6d48412ae9f94122ef463d · InternLM/InternEvo · GitHub), I profiled and saw 3 NCCL streams for p2p comms, matching `local_dkv_comm`, `local_kv_comm` and `dkv_comm`.
However, I couldn't reproduce this elsewhere with 2 process groups.
Is it because torch only uses different streams when there are multiple process groups?
fduwjj
(Hugo)
August 7, 2024, 6:29pm
2
IIUC, for each ProcessGroup we assign the work to its own NCCL stream. So if you want 3 streams, maybe you want to create a separate PG for each comm?
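A rough sketch of what I mean (group names are just placeholders; run with at least 2 ranks):

```python
import torch
import torch.distributed as dist

# One ProcessGroup per kind of p2p traffic, so each gets its own NCCL
# communicator and its comms are queued on that group's own stream.
dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
ranks = list(range(world_size))

local_kv_pg = dist.new_group(ranks, backend="nccl")
local_dkv_pg = dist.new_group(ranks, backend="nccl")
dkv_pg = dist.new_group(ranks, backend="nccl")

# p2p calls issued on different groups land on different NCCL streams,
# so unrelated transfers can overlap.
buf = torch.randn(1024, device="cuda")
out = torch.empty(1024, device="cuda")
reqs = []
if rank == 0:
    reqs = [dist.isend(buf, 1, group=local_kv_pg), dist.isend(buf, 1, group=dkv_pg)]
elif rank == 1:
    reqs = [dist.irecv(out, 0, group=local_kv_pg), dist.irecv(out, 0, group=dkv_pg)]
for req in reqs:
    req.wait()
```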
Edenzzzz
(Edenzzzz)
August 12, 2024, 12:52pm
3
That's right. Closing in favour of the answer on this GitHub issue:
## 🚀 Feature
Make streams used for NCCL operations configurable
## Motivation
I've noticed that the PyTorch distributed module has introduced P2P send and receive functionality via NCCL (although it is still listed as "not supported" in the documentation). However, I found cases where NCCL can hang when two nodes exchange tensors with NCCL P2P `send` and `recv`. Here is a (possibly) minimal working example.
```python
import torch
import torch.distributed as dist
from argparse import ArgumentParser

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--rank", type=int)
    args = parser.parse_args()

    local_rank, remote_rank = args.rank, 1 - args.rank
    device = torch.device('cuda', local_rank)
    torch.cuda.set_device(device)

    dist.init_process_group(backend="nccl", init_method='tcp://127.0.0.1:30001',
                            rank=local_rank, world_size=2)
    dist.barrier()
    print("Ready")

    tensor_to_send = torch.empty((1000000,), device=device)
    tensor_to_recv = torch.empty((1000000,), device=device)

    # exchange tensors between two nodes
    dist.send(tensor_to_send, remote_rank)
    dist.recv(tensor_to_recv, remote_rank)
    torch.cuda.synchronize()  # never returns
    print("Done")
```
I believe running two processes with command arguments `--rank 0` and `--rank 1` should not raise any problems. However, they will likely end up hanging at `torch.cuda.synchronize()` with 100% GPU utilization (if not, please try with larger tensors).
This is because `dist.send` and `dist.recv`, which internally call `ncclSend` and `ncclRecv`, can deadlock. `ncclSend` blocks until the remote rank calls `ncclRecv`, and vice versa (see https://github.com/NVIDIA/nccl/issues/584). Thus, when `node 0` and `node 1` issue `dist.send` at the same time, neither can proceed to `dist.recv`. This makes issuing successive calls to `dist.send` and `dist.recv` very prone to deadlock.
To handle this, the NCCL documentation recommends grouping `ncclSend` and `ncclRecv` operations with `ncclGroupStart` and `ncclGroupEnd` (`dist.P2POp` and `dist.batch_isend_irecv` in PyTorch). However, this does not always work, especially when two operations cannot be grouped. For example, when two nodes randomly send and receive tensors to each other, they may simultaneously call `send` and fall into a deadlock, as there is no locking mechanism here.
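For the symmetric exchange in the MWE above, the grouped form would look roughly like this (a sketch reusing the same variables as the example):

```python
# Both p2p ops are issued as one batch, so the two ranks cannot deadlock
# on this particular pair of transfers.
ops = [
    dist.P2POp(dist.isend, tensor_to_send, remote_rank),
    dist.P2POp(dist.irecv, tensor_to_recv, remote_rank),
]
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
    req.wait()
torch.cuda.synchronize()
```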
In my project, there are multiple nodes exchanging tensors with each other, and any node can send or receive a tensor. The sender transmits a tensor and its corresponding metadata to the receiver as needed. To achieve this, metadata containing the size of the tensor is delivered to the receiver out-of-band via RPC. The receiver calls `dist.recv` upon receiving the metadata, and the sender calls `dist.send` to send the tensor. Here, when two nodes happen to call `dist.send` simultaneously, the GPUs hang and never return from `synchronize()`.
## Pitch
IMHO, the easiest way to resolve this is to use different CUDA streams for `dist.send` and `dist.recv`. Unfortunately, it seems there is no viable way for users to specify a particular stream for NCCL operations. PyTorch distributed uses a dedicated stream for each NCCL connection (https://github.com/pytorch/pytorch/issues/60511), and `with torch.cuda.stream(stream)` does not work here as it does for other CUDA operations.
I would like to have a context manager that can configure the stream used for NCCL operations, something like below:
```python
# At the sender side
stream_send = dist.NCCLStream(remote_rank)  # Create stream for NCCL operation
with dist.use_stream(stream_send):
    # this context should not affect other CUDA operations
    dist.send(tensor_to_send, remote_rank)
torch.cuda.synchronize()

# At the receiver side
stream_recv = dist.NCCLStream(remote_rank)  # Create stream for NCCL operation
with dist.use_stream(stream_recv):
    # this context should not affect other CUDA operations
    dist.recv(tensor_to_recv, remote_rank)
torch.cuda.synchronize()
```
so that `dist.send` and `dist.recv` do not block each other.
## Alternatives
Well, we can implement some hacky locking mechanisms to prevent two nodes from calling `dist.send` at the same time. However, this would come with some overhead, and doesn't seem to be a good solution...
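For the symmetric exchange in the MWE above, one such coordination scheme can be as simple as a fixed per-pair ordering rather than a real lock (a sketch reusing the variables from the example):

```python
# The lower rank always sends first and the higher rank receives first,
# so the two ranks never block inside send() at the same time.
if local_rank < remote_rank:
    dist.send(tensor_to_send, remote_rank)
    dist.recv(tensor_to_recv, remote_rank)
else:
    dist.recv(tensor_to_recv, remote_rank)
    dist.send(tensor_to_send, remote_rank)
torch.cuda.synchronize()
```

This doesn't help the RPC-driven case, though, since a node does not know in advance whether its peer is also about to send.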
## Additional context
I'll be glad to know if other workarounds can be applied to my problem.
I'd like to work on this and submit a PR if there is no workaround to resolve this problem.
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang