Distributed errors with Send/Recv and NCCL

I am trying to follow this tutorial and send a tensor from one GPU to another using send and recv as described. Here is the stripped-down script I am using:

import os
import torch
import torch.distributed as dist


if __name__ == "__main__":

    world_size = int(os.environ["WORLD_SIZE"])
    global_rank = int(os.environ['SLURM_PROCID'])
    local_rank = int(os.environ['SLURM_LOCALID'])

    mygroup = dist.init_process_group(backend="nccl", world_size=world_size, rank=global_rank)
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    tokens = torch.zeros(1)
    tokens = tokens.to(device)
    if global_rank == 0:
        tokens += 1
        # Send the tensor to process 1
        dist.send(tensor=tokens, dst=1)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tokens, src=0)
    print('Rank ', global_rank, ' has data ', tokens[0])

    # Print if successfully through
    print(f'I have local rank {local_rank}, global rank {global_rank}, with a world_size of {world_size}.')

and the slurm submit script:

#!/bin/bash

#Submit this script with: sbatch filename
#SBATCH --time=1:00:00   # walltime
#SBATCH --nodes=2   # number of nodes
#SBATCH --ntasks-per-node=4   # number of tasks per node
#SBATCH --job-name=gpt1   # job name
#SBATCH --qos=standard   # qos name
#SBATCH --mem=0
#SBATCH --partition=gpu   # partition name


echo "NODELIST="${SLURM_NODELIST}
export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
export MASTER_PORT=12340
export WORLD_SIZE=8
export NCCL_DEBUG=INFO 


# Run your training script
srun python test2.py

This leads to the following error:

NODELIST=nid[001384-001385]
nid001384:22415:22415 [0] NCCL INFO Bootstrap : Using nmn0:10.100.4.96<0>
nid001384:22415:22415 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
nid001384:22415:22415 [0] NCCL INFO cudaDriverVersion 11070
NCCL version 2.14.3+cuda11.7
nid001384:22415:25205 [0] NCCL INFO NET/IB : No device found.
nid001384:22415:25205 [0] NCCL INFO NET/Socket : Using [0]nmn0:10.100.4.96<0> [1]hsn0:10.253.6.168<0> [2]hsn1:10.253.6.167<0>
nid001384:22415:25205 [0] NCCL INFO Using network Socket
nid001384:22415:25205 [0] NCCL INFO Setting affinity for GPU 0 to ffff0000,00000000,ffff0000,00000000
nid001384:22415:25205 [0] NCCL INFO Channel 00/08 :    0   1
nid001384:22415:25205 [0] NCCL INFO Channel 01/08 :    0   1
nid001384:22415:25205 [0] NCCL INFO Channel 02/08 :    0   1
nid001384:22415:25205 [0] NCCL INFO Channel 03/08 :    0   1
nid001384:22415:25205 [0] NCCL INFO Channel 04/08 :    0   1
nid001384:22415:25205 [0] NCCL INFO Channel 05/08 :    0   1
nid001384:22415:25205 [0] NCCL INFO Channel 06/08 :    0   1
nid001384:22415:25205 [0] NCCL INFO Channel 07/08 :    0   1
nid001384:22416:25250 [1] NCCL INFO Channel 05/0 : 1[41000] -> 0[3000] via P2P/IPC/read
nid001384:22416:25250 [1] NCCL INFO Channel 06/0 : 1[41000] -> 0[3000] via P2P/IPC/read
nid001384:22416:25250 [1] NCCL INFO Channel 07/0 : 1[41000] -> 0[3000] via P2P/IPC/read
nid001384:22416:25250 [1] NCCL INFO Connected all rings
nid001384:22416:25250 [1] NCCL INFO Connected all trees
nid001384:22416:25250 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
nid001384:22416:25250 [1] NCCL INFO 8 coll channels, 8 p2p channels, 8 p2p channels per peer
nid001384:22416:25250 [1] NCCL INFO comm 0x39049550 rank 1 nranks 2 cudaDev 1 busId 41000 - Init COMPLETE
Rank  1  has data  tensor(1., device='cuda:1')
I have local rank 1, global rank 1, with a world_size of 8.
Rank  0  has data  tensor(1., device='cuda:0')
I have local rank 0, global rank 0, with a world_size of 8.
nid001384:22416:25324 [1] NCCL INFO [Service thread] Connection closed by localRank 1
nid001384:22415:25327 [0] NCCL INFO [Service thread] Connection closed by localRank 0
Traceback (most recent call last):
  File "/users/jsmidt/AI/MPI/test2.py", line 24, in <module>
Traceback (most recent call last):
  File "/users/jsmidt/AI/MPI/test2.py", line 24, in <module>
Traceback (most recent call last):
  File "/users/jsmidt/AI/MPI/test2.py", line 24, in <module>
Traceback (most recent call last):
  File "/users/jsmidt/AI/MPI/test2.py", line 24, in <module>
    dist.recv(tensor=tokens, src=0)
  File "/users/jsmidt/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1338, in recv
    dist.recv(tensor=tokens, src=0)
  File "/users/jsmidt/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1338, in recv
Traceback (most recent call last):
  File "/users/jsmidt/AI/MPI/test2.py", line 24, in <module>
Traceback (most recent call last):
  File "/users/jsmidt/AI/MPI/test2.py", line 24, in <module>
    pg.recv([tensor], src, tag).wait()
RuntimeError: [2] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0:2', but store->get('0:2') got error: Connection reset by peer. This may indicate a possible application crash on rank 0 or a network set up issue.
    dist.recv(tensor=tokens, src=0)
  File "/users/jsmidt/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1338, in recv
    dist.recv(tensor=tokens, src=0)
  File "/users/jsmidt/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1338, in recv
    dist.recv(tensor=tokens, src=0)
  File "/users/jsmidt/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1338, in recv
    pg.recv([tensor], src, tag).wait()
RuntimeError: [3] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0:3', but store->get('0:3') got error: Connection reset by peer. This may indicate a possible application crash on rank 0 or a network set up issue.

I have tried launching this with torchrun and also get errors. I fear there is something subtle but fundamentally wrong with my script that I am missing. Thanks!

For posterity, what seemed to fix the problem was creating an explicit init_method (along with having rank 0 send to every other rank, since every non-zero rank calls recv from rank 0). Here is the new script that works with the above slurm script for me:

import os
import torch
import torch.distributed as dist


if __name__ == "__main__":

    world_size = int(os.environ["WORLD_SIZE"])
    global_rank = int(os.environ['SLURM_PROCID'])
    local_rank = int(os.environ['SLURM_LOCALID'])
    init_str = f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}'

    mygroup = dist.init_process_group(backend="nccl", init_method=init_str, world_size=world_size, rank=global_rank)
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    tokens = torch.zeros(1)
    tokens = tokens.to(device)
    if global_rank == 0:
        tokens += 1
        # Send the tensor to every other rank
        for i in range(1, world_size):
            dist.send(tensor=tokens, dst=i)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tokens, src=0)
    print('Rank ', global_rank, ' has data ', tokens[0])

    # Print if successfully through
    print(f'I have local rank {local_rank}, global rank {global_rank}, with a world_size of {world_size}.')
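
As a side note, since every rank ends up with a copy of rank 0's tensor, the same one-to-many pattern can also be written as a single collective with dist.broadcast instead of a loop of matched send/recv pairs. Below is a minimal sketch, assuming the same Slurm environment variables and the same tcp:// init_method as in the script above:

import os
import torch
import torch.distributed as dist


if __name__ == "__main__":

    world_size = int(os.environ["WORLD_SIZE"])
    global_rank = int(os.environ["SLURM_PROCID"])
    local_rank = int(os.environ["SLURM_LOCALID"])
    init_str = f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}'

    dist.init_process_group(backend="nccl", init_method=init_str,
                            world_size=world_size, rank=global_rank)
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    # Rank 0 holds the value to distribute; every other rank starts from zeros.
    tokens = torch.zeros(1, device=device)
    if global_rank == 0:
        tokens += 1

    # broadcast copies rank 0's tensor into the matching tensor on all other ranks,
    # so there are no individual send/recv pairs to keep in sync with world_size.
    dist.broadcast(tokens, src=0)
    print('Rank ', global_rank, ' has data ', tokens[0])

    dist.destroy_process_group()

This launches the same way (srun python ...) under the slurm script above.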
