Correct usage of torch.distributed.run (multi-node multi-gpu)

Hello,

I used to launch a multi-node, multi-GPU job with torch.distributed.launch on two cloud servers, using a separate .sh script on each machine:

#machine 1 script
export NUM_NODES=2
export NUM_GPUS_PER_NODE=4
export HOST_NODE_ADDR=10.70.202.133

python -m torch.distributed.launch \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank=0 \
    --master_addr=$HOST_NODE_ADDR \
    --master_port=1234 \
    train.py --debug

#machine 2 script
export NUM_NODES=2
export NUM_GPUS_PER_NODE=4
export HOST_NODE_ADDR=10.70.202.133

python -m torch.distributed.launch \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank=1 \
    --master_addr=$HOST_NODE_ADDR \
    --master_port=1234 \
    train.py --debug

When I started working with PyTorch 1.9, it warned that torch.distributed.launch is deprecated and that I should migrate to torch.distributed.run. Unfortunately, the documentation does not give enough information about how this module replaces torch.distributed.launch. When I try the new method with the following two scripts, one per machine:

#machine 1
export NUM_NODES=2
export NUM_GPUS_PER_NODE=4
export HOST_NODE_ADDR=10.70.202.133:1234
export JOB_ID=22641

python -m torch.distributed.run \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --node_rank=0 \
    --rdzv_id=$JOB_ID \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$HOST_NODE_ADDR \
    train.py --debug

#machine 2
export NUM_NODES=2
export NUM_GPUS_PER_NODE=4
export HOST_NODE_ADDR=10.70.202.133:1234
export JOB_ID=22641

python -m torch.distributed.run \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --node_rank=1 \
    --rdzv_id=$JOB_ID \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$HOST_NODE_ADDR \
    train.py --debug

I get this error:

[ERROR] 2021-07-09 19:37:35,417 error_handler: {
  "message": {
    "message": "RendezvousConnectionError: The connection to the C10d store has failed. See inner exception for details.",
    "extraInfo": {

Here’s how I set up my training script:

if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        config["RANK"] = int(os.environ["RANK"])
        config["WORLD_SIZE"] = int(os.environ["WORLD_SIZE"])
        config["GPU"] = int(os.environ["LOCAL_RANK"])
    else:
        config["DISTRIBUTED"] = False
        config["GPU"] = 1
        config["WORLD_SIZE"] = 1
        return

    config["DISTRIBUTED"] = True
    torch.cuda.set_device(config["GPU"])
    config["DIST_BACKEND"] = "nccl"
    torch.distributed.init_process_group(
        backend=config["DIST_BACKEND"],
        init_method=config["DIST_URL"],
        world_size=config["WORLD_SIZE"],
        rank=config["RANK"],
    )

and I also parse the following argument:

parser.add_argument("--local_rank", type=int)

Am I missing something?

Thanks for raising this issue! Since this seems like it could be a bug, or at the very least a migration issue, can you file an issue (essentially this post) over at Issues · pytorch/pytorch · GitHub so that we can take a deeper look?

cc @cbalioglu @H-Huang @Kiuk_Chung @aivanou

Also, IIUC, torch.distributed.run should be fully backward compatible with torch.distributed.launch. Have you tried simply dropping in torch.distributed.run with the same launch arguments, and if so, what sort of issues did you hit there?

The docs for torch.distributed.launch|run need some improvements to match the warning message. This is being tracked here: dist docs need an urgent serious update · Issue #60754 · pytorch/pytorch · GitHub. Most of it has already been addressed in the nightly docs: torch.distributed.run (Elastic Launch) — PyTorch master documentation.

For the time being, here are the things to note:

  1. torch.distributed.run does not support parsing --local_rank as a command-line argument. If your script does this, change it to read the local rank from int(os.environ["LOCAL_RANK"]) instead (see the sketch after this list). If you can’t change the script, stick with torch.distributed.launch for now.

  2. As @rvarm1 mentioned, torch.distributed.run's arguments are mostly backward compatible with torch.distributed.launch (the exception is --use_env, which is now True by default, since we plan to deprecate reading local_rank from command-line arguments in favor of the environment variable).
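
For reference, here is a minimal sketch of point (1), reading the local rank from the environment that torch.distributed.run populates instead of parsing --local_rank (an NCCL backend and env:// initialization are assumed here):

import os
import torch

# torch.distributed.run exports LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and
# MASTER_PORT for every worker, so nothing has to be parsed from the command line.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")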

I tried torch.distributed.run with the same legacy arguments, and it works. It seems to have a problem with the new rendezvous (rdzv) arguments (or maybe I am not setting them up correctly).
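
For completeness, the drop-in invocation with the legacy-style arguments looks something like this on machine 1 (machine 2 is identical apart from --node_rank=1):

# same flags as the old torch.distributed.launch command; only the module changes
python -m torch.distributed.run \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank=0 \
    --master_addr=10.70.202.133 \
    --master_port=1234 \
    train.py --debug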