DistributedDataParallel Socket Timeout

I’m trying to use PyTorch’s DistributedDataParallel together with PyTorch Geometric to train a GNN on multiple GPUs. I am following an example similar to the code shown below.

But it keeps timing out. Here is the error I’m getting:

[E socket.cpp:793] [c10d] The client socket has timed out after 30s while trying to connect to (localhost, 12355).
Traceback (most recent call last):
  File "/home/john/Documents/MachineLearning/gnn/Models/aws/gnn/LP_V2_wandb_aws_dist.py", line 594, in <module>
    run(0, world_size)
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/home/John/Documents/MachineLearning/gnn/Models/aws/gnn/LP_V2_wandb_aws_dist.py", line 496, in run
    setup(rank, world_size)
  File "/home/john/Documents/MachineLearning/gnn/Models/aws/gnn/LP_V2_wandb_aws_dist.py", line 480, in setup
    dist.init_process_group("nccl", rank=rank, world_size=world_size) #, timeout=timedelta(days=1))
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 595, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/rendezvous.py", line 257, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/rendezvous.py", line 184, in _create_c10d_store
    tcp_store = TCPStore(hostname, port, world_size, False, timeout)
TimeoutError: The client socket has timed out after 30s while trying to connect to (localhost, 12355).
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 56696) of binary: /home/john/anaconda3/envs/neo4j/bin/python
ERROR:torch.distributed.elastic.multiprocessing.errors.error_handler:no error file defined for parent, to copy child error file (/tmp/torchelastic_do_14voc/none_4pyl8pd1/attempt_0/0/error.json)
Traceback (most recent call last):
  File "/home/john/anaconda3/envs/neo4j/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==1.12.1', 'console_scripts', 'torchrun')())
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/run.py", line 761, in main
    run(args)
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/run.py", line 752, in run
    elastic_launch(
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
LP_V2_wandb_aws_dist.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-03-06_15:28:05
  host      : omen-30l-desktop
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 56696)
  error_file: /tmp/torchelastic_do_14voc/none_4pyl8pd1/attempt_0/0/error.json
  traceback : Traceback (most recent call last):
    File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
      return f(*args, **kwargs)
    File "/home/john/Documents/MachineLearning/gnn/Models/aws/gnn/LP_V2_wandb_aws_dist.py", line 496, in run
      setup(rank, world_size)
    File "/home/john/Documents/MachineLearning/gnn/Models/aws/gnn/LP_V2_wandb_aws_dist.py", line 480, in setup
      dist.init_process_group("nccl", rank=rank, world_size=world_size) #, timeout=timedelta(days=1))
    File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 595, in init_process_group
      store, rank, world_size = next(rendezvous_iterator)
    File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/rendezvous.py", line 257, in _env_rendezvous_handler
      store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
    File "/home/john/anaconda3/envs/neo4j/lib/python3.10/site-packages/torch/distributed/rendezvous.py", line 184, in _create_c10d_store
      tcp_store = TCPStore(hostname, port, world_size, False, timeout)
  TimeoutError: The client socket has timed out after 3000s while trying to connect to (localhost, 12355).

This happens in the setup function whether I’m using a single GPU or multiple GPUs on an AWS P3 cluster. I have tried increasing the timeout, changing NCCL_SOCKET_IFNAME, changing the backend from nccl to gloo, and a number of other suggestions from various forums, but nothing seems to work.
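For reference, the variations I tried looked roughly like this (a sketch only; the function name, interface name, and timeout value are illustrative, not my exact code):

import os
from datetime import timedelta

import torch.distributed as dist

def setup_with_workarounds(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # workaround 1: pin NCCL to a specific network interface
    # (the interface name here is just an example)
    os.environ['NCCL_SOCKET_IFNAME'] = 'lo'
    # workaround 2: a much longer rendezvous timeout;
    # I also tried "gloo" in place of "nccl" here
    dist.init_process_group(
        "nccl",
        rank=rank,
        world_size=world_size,
        timeout=timedelta(minutes=60),
    )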

Here are the environment details:
Ubuntu: 22.04 LTS
PyTorch: 1.12.1
PyTorch Geometric: 2.2.0
CUDA: 11.7
cuDNN: 8.5.0

Here are the relevant parts of my code:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import wandb

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run(rank, world_size):

    set_seed()        # helper defined elsewhere
    free_gpu_cache()  # helper defined elsewhere

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    ....................

if __name__ == '__main__':

    world_size = torch.cuda.device_count()

    if world_size == 1:
        run(0, world_size)
    else:
        # one process per GPU
        mp.spawn(run, args=(world_size,), nprocs=world_size)

    wandb.init().finish()

Please let me know if any other information would be helpful. Thank you!

Update: for anyone running into the same thing, this turned out to be caused by a wandb setup issue.
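I didn’t keep the exact diff, but the gist was to keep wandb initialization out of the spawning path. A pattern along these lines avoids the conflict (a rough sketch with illustrative names, not my exact code):

import torch
import torch.multiprocessing as mp
import wandb

def run(rank, world_size):
    setup(rank, world_size)
    # only rank 0 talks to wandb; the other ranks skip logging entirely
    if rank == 0:
        wandb.init(project="gnn-ddp")  # project name is illustrative
    ....................
    if rank == 0:
        wandb.finish()
    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    wandb.setup()  # prepare wandb in the parent before spawning workers
    mp.spawn(run, args=(world_size,), nprocs=world_size)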