SLURM torch.distributed broadcast

I’m trying to reproduce the MLPerf v0.7 NVIDIA submission for BERT on a SLURM system. In doing so I encountered an error. Below I’ve included a minimal reproducible example:

test.sh:

#!/bin/env bash
#SBATCH --gpus-per-node=T4:2
#SBATCH -N 1
#SBATCH -t 0-00:05:00

cd $TMPDIR
echo "
import os
import torch
import torch.distributed

torch.distributed.init_process_group('nccl')

for var_name in ['SLURM_LOCALID', 'SLURM_PROCID', 'SLURM_NTASKS']:
    print(f'{var_name} = {os.environ.get(var_name)}')
local_rank = int(os.environ['SLURM_LOCALID'])
torch.cuda.set_device(local_rank)
seeds_tensor = torch.LongTensor(5).random_(0, 2**32 - 1).to('cuda')
torch.distributed.broadcast(seeds_tensor, 0)
print('Broadcast successful')
" > tmp.py

srun -l --mpi=none --ntasks=2 --ntasks-per-node=2 singularity exec $SLURM_SUBMIT_DIR/PyTorch-1.8.1.sif python -m torch.distributed.launch --use_env --nproc_per_node=2 tmp.py

I then launch this with sbatch test.sh. PyTorch-1.8.1.sif is built from the official PyTorch Docker image (docker pull pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel).

The output is:

0:   File "tmp.py", line 13, in <module>
0:     Traceback (most recent call last):
0: torch.distributed.broadcast(seeds_tensor, 0)
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1039, in broadcast
0:   File "tmp.py", line 13, in <module>
0:     torch.distributed.broadcast(seeds_tensor, 0)
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1039, in broadcast
0:     work = default_pg.broadcast([tensor], opts)
0: RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1616554786529/work/torch/lib/c10d/ProcessGroupNCCL.cpp:825, invalid usage, NCCL version 2.7.8
0: ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).
0:     work = default_pg.broadcast([tensor], opts)
0: RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1616554786529/work/torch/lib/c10d/ProcessGroupNCCL.cpp:825, invalid usage, NCCL version 2.7.8
0: ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).
0: Traceback (most recent call last):
0:   File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
0:     "__main__", mod_spec)
0:   File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
0:     exec(code, run_globals)
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launch.py", line 340, in <module>
0:     main()
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launch.py", line 326, in main
0:     sigkill_handler(signal.SIGTERM, None)  # not coming back
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
srun: error: alvis2-04: task 0: Exited with exit code 1
0:     raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
0: subprocess.CalledProcessError: Command '['/opt/conda/bin/python', '-u', 'tmp.py']' returned non-zero exit status 1.
0: *****************************************
0: Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
0: *****************************************
0: Killing subprocess 206223
0: Killing subprocess 206225
0: Traceback (most recent call last):
0:   File "tmp.py", line 6, in <module>
0:     torch.distributed.init_process_group('nccl')
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 500, in init_process_group
0:     store, rank, world_size = next(rendezvous_iterator)
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/rendezvous.py", line 190, in _env_rendezvous_handler
0:     store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
0: RuntimeError: Address already in use
0: SLURM_LOCALID = 0
0: SLURM_PROCID = 0
0: SLURM_NTASKS = 2
1: SLURM_LOCALID = 1
1: SLURM_PROCID = 1
1: SLURM_NTASKS = 2
0: Traceback (most recent call last):
0:   File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
0:     "__main__", mod_spec)
0:   File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
0:     exec(code, run_globals)
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launch.py", line 340, in <module>
0:     main()
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launch.py", line 326, in main
0:     sigkill_handler(signal.SIGTERM, None)  # not coming back
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
0:     raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
0: subprocess.CalledProcessError: Command '['/opt/conda/bin/python', '-u', 'tmp.py']' returned non-zero exit status 1.
0: *****************************************
srun: error: alvis2-08: task 0: Exited with exit code 1
0: Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
0: *****************************************
0: Killing subprocess 49549
0: Killing subprocess 49551
srun: Job step aborted: Waiting up to 122 seconds for job step to finish.
0: slurmstepd: error: *** STEP 111669.0 ON alvis2-08 CANCELLED AT 2021-10-11T10:23:52 DUE TO TIME LIMIT ***
slurmstepd: error: *** JOB 111669 ON alvis2-08 CANCELLED AT 2021-10-11T10:23:52 DUE TO TIME LIMIT ***
1: *****************************************

So there are two different errors here. The first comes from torch.distributed.init_process_group('nccl'):

0:     store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
0: RuntimeError: Address already in use

and the second comes from torch.distributed.broadcast:

0: torch.distributed.broadcast(seeds_tensor, 0)
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1039, in broadcast
0:   File "tmp.py", line 13, in <module>
0:     torch.distributed.broadcast(seeds_tensor, 0)
0:   File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1039, in broadcast
0:     work = default_pg.broadcast([tensor], opts)
0: RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1616554786529/work/torch/lib/c10d/ProcessGroupNCCL.cpp:825, invalid usage, NCCL version 2.7.8
0: ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).

Update: only the Address already in use error seems to remain if the last line is replaced with

srun -l --mpi=none --ntasks=2 --ntasks-per-node=2 singularity exec $SLURM_SUBMIT_DIR/PyTorch-1.8.1.sif python -m torch.distributed.launch --use_env tmp.py

Can you also add

print(f"MASTER_ADDR: ${os.environ['MASTER_ADDR']}")
print(f"MASTER_PORT: ${os.environ['MASTER_PORT']}")

before torch.distributed.init_process_group("nccl")? That may give some insight into what endpoint is being used. Once you have verified the address and port, check that there is no process currently using that address/port combination. You can also try instantiating a TCPStore from the Python command line to verify that it works. It is likely a port conflict, depending on what you set the port number to in the environment variables.
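If it helps, here is a minimal sketch of that TCPStore check, assuming MASTER_ADDR and MASTER_PORT are already set in the environment (as torch.distributed.launch --use_env does) and that you run it on the node that is supposed to act as rank 0:

import os
from datetime import timedelta
from torch.distributed import TCPStore

addr = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])

# is_master=True makes this process bind and listen on the port, which is what
# the env:// rendezvous does on rank 0. If another process already owns the
# port you should see the same 'Address already in use' error here.
store = TCPStore(addr, port, 1, True, timedelta(seconds=30))
print(f'TCPStore is listening on {addr}:{port}')

If that binds cleanly on its own but fails inside your srun line, then two processes in the job step are most likely both trying to act as the rendezvous master on the same port, which would match the port-conflict explanation above.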