### 🐛 Describe the bug
When running torch RPC on multiple nodes with submitit (i.e., through SLURM), I get an EOF error _even if I'm not using the GPUs and not making them available to RPC_.
Here's a script to reproduce:
```python
import os
import socket
import subprocess
import time

import submitit
import torch
from torch.distributed import rpc

MAX_TIME_TO_CONNECT = 1000


def rpc_init_node(
    rank,
    rank0_ip,
    tcp_port,
    world_size,
):
    DEVICES = []  # list(range(torch.cuda.device_count()))
    os.environ["MASTER_ADDR"] = str(rank0_ip)
    os.environ["MASTER_PORT"] = "29500"
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    # os.environ['TP_SOCKET_IFNAME'] = 'lo'
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        init_method=f"tcp://{rank0_ip}:{tcp_port}",
        rpc_timeout=MAX_TIME_TO_CONNECT,
        _transports=["uv"],
        # Currently fails when nodes have more than 0 gpus avail,
        # even when no device is made visible
        devices=DEVICES,
    )
    print(f"init rpc on {rank}")
    rpc.init_rpc(
        f"NODE_{rank}",
        rank=rank,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
        world_size=world_size,
    )
    rpc.shutdown()


def rpc_init_master(
    tcp_port,
    world_size,
):
    hostname = socket.gethostname()
    rank0_ip = socket.gethostbyname(hostname)
    DEVICES = []  # list(range(torch.cuda.device_count()))
    os.environ["MASTER_ADDR"] = str(rank0_ip)
    os.environ["MASTER_PORT"] = "29500"
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    # os.environ['TP_SOCKET_IFNAME'] = 'lo'
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        init_method=f"tcp://{rank0_ip}:{tcp_port}",
        rpc_timeout=MAX_TIME_TO_CONNECT,
        _transports=["uv"],
        # Currently fails when nodes have more than 0 gpus avail,
        # even when no device is made visible
        devices=DEVICES,
    )
    print("init rpc on master")
    rpc.init_rpc(
        "TRAINER",
        rank=0,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
        world_size=world_size,
    )
    # some dummy compute
    out = rpc.rpc_sync("NODE_1", torch.add, args=(torch.ones(()), torch.ones(())))
    rpc.shutdown()
    print("result", out)
    return out.item()


if __name__ == "__main__":
    slurm_conf = {
        "timeout_min": 100,
        "slurm_partition": "train",
        "slurm_cpus_per_task": 4,  # works
        # "slurm_gpus_per_task": 1, "slurm_cpus_per_gpu": 8,  # does not work
    }
    master_on_node = True
    num_nodes = 2
    executor = submitit.AutoExecutor(folder="log_test")
    executor.update_parameters(**slurm_conf)
    if not master_on_node:
        hostname = socket.gethostname()
        IPAddr = socket.gethostbyname(hostname)
    else:
        job = executor.submit(rpc_init_master, 1234, num_nodes + 1)
        print("job id", job.job_id)
        time.sleep(2.0)
        cmd = f"squeue -j {job.job_id} -o %N | tail -1"
        node = subprocess.check_output(cmd, shell=True, text=True).strip()
        print("node", node)
        cmd = f"sinfo -n {node} -O nodeaddr | tail -1"
        print(cmd)
        IPAddr = subprocess.check_output(cmd, shell=True, text=True).strip()
        print("IP addr:", IPAddr)
    for i in range(num_nodes):
        _job = executor.submit(rpc_init_node, i + 1, IPAddr, 1234, num_nodes + 1)
    if not master_on_node:
        out = rpc_init_master(1234, num_nodes + 1)
    else:
        out = job.result()
    print("result", out)
```
The line that triggers the failure is commented out in the script above: to reproduce the error, uncomment the `slurm_gpus_per_task` line and comment out the line tagged `# works`.
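For clarity, the failing configuration is the same dict with only the SLURM resources changed:

```python
# Failing configuration: requesting GPUs for the tasks triggers the EOF error,
# even though devices=[] is passed to TensorPipeRpcBackendOptions.
slurm_conf = {
    "timeout_min": 100,
    "slurm_partition": "train",
    "slurm_gpus_per_task": 1,
    "slurm_cpus_per_gpu": 8,  # does not work
}
```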
### What does not matter
- Passing `devices=list_of_devices` or `devices=[]` to the backend options has the same effect.
- Launching everything from the master node or submitting a dedicated master job (see the script) gives the same error.
- The code runs fine with multiprocessing, presumably because everything stays on the same node (see the sketch below).

I had to set `_transports=["uv"]` in the TensorPipe options because I'm running on AWS; without it, the code doesn't run at all.
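For reference, here is a minimal single-node sketch of roughly what I mean by the multiprocessing variant (worker names and world size are illustrative, not my exact script); this version works:

```python
# Illustrative single-node variant: all ranks live in local processes spawned
# with torch.multiprocessing, so no cross-node connection is involved.
import os

import torch
import torch.multiprocessing as mp
from torch.distributed import rpc

WORLD_SIZE = 3  # 1 trainer + 2 workers


def run(rank):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        _transports=["uv"],
        devices=[],
    )
    name = "TRAINER" if rank == 0 else f"NODE_{rank}"
    rpc.init_rpc(
        name,
        rank=rank,
        world_size=WORLD_SIZE,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
    )
    if rank == 0:
        out = rpc.rpc_sync("NODE_1", torch.add, args=(torch.ones(()), torch.ones(())))
        print("result", out)
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run, nprocs=WORLD_SIZE)
```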
Here's the error:
```
Traceback (most recent call last):
File "/data/home/vmoens/dump/dummy.py", line 109, in <module>
out = job.result()
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/core.py", line 266, in result
r = self.results()
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/core.py", line 294, in results
raise job_exception # pylint: disable=raising-bad-type
submitit.core.utils.FailedJobError: Job (task=0) failed during processing with trace:
----------------------
Traceback (most recent call last):
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/submission.py", line 54, in process_job
result = delayed.result()
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/submitit/core/utils.py", line 133, in result
self._result = self.function(*self.args, **self.kwargs)
File "/data/home/vmoens/dump/dummy.py", line 63, in rpc_init_master
rpc.init_rpc(
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/__init__.py", line 199, in init_rpc
_init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/__init__.py", line 234, in _init_rpc_backend
rpc_agent = backend_registry.init_backend(
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/backend_registry.py", line 104, in init_backend
return backend.value.init_backend_handler(*args, **kwargs)
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/backend_registry.py", line 363, in _tensorpipe_init_backend_handler
api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 82, in wrapper
return func(*args, **kwargs)
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 224, in _all_gather
rpc_sync(
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 82, in wrapper
return func(*args, **kwargs)
File "/fsx/users/vmoens/conda/envs/compile/lib/python3.9/site-packages/torch/distributed/rpc/api.py", line 809, in rpc_sync
return fut.wait()
RuntimeError: EOF: end of file (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)
```
### Versions
Latest torch nightly, locally built
```
PyTorch version: 2.0.0.dev20230220+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~18.04) 9.4.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.25.0
Libc version: glibc-2.27
Python version: 3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 15:55:03) [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-1069-aws-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB
Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 1369.306
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 36608K
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.0.0+c8bfe3f548
[pip3] torch==2.0.0.dev20230220+cu118
[pip3] torchaudio==2.0.0.dev20230222+cu118
[pip3] torchrl==0.0.4+46ec988
[pip3] torchsnapshot==0.1.0
[pip3] torchvision==0.15.0.dev20230221+cu118
[conda] magma-cuda110 2.5.2 1 pytorch
[conda] mkl 2022.1.0 hc2b9512_224
[conda] mkl-include 2023.0.0 h84fe81f_26648 conda-forge
[conda] numpy 1.24.1 pypi_0 pypi
[conda] pytorch-triton 2.0.0+c8bfe3f548 pypi_0 pypi
[conda] torch 2.0.0a0+gitd677432 dev_0 <develop>
[conda] torchaudio 2.0.0.dev20230222+cu118 pypi_0 pypi
[conda] torchrl 0.0.4+46ec988 pypi_0 pypi
[conda] torchsnapshot 0.1.0 pypi_0 pypi
[conda] torchvision 0.15.0.dev20230221+cu118 pypi_0 pypi
```
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @pietern @jjlilley @mrzzd @lw @beauby