DDP logging into files

This is a follow-up of this.
It is not urgent, as this feature seems to still be in development and not documented.

PyTorch 1.9.0

Hi,

Logging in DDP:

  1. When using torch.distributed.run instead of torch.distributed.launch, my code freezes. I switched because I got this warning: The module torch.distributed.launch is deprecated and going to be removed in future. Migrate to torch.distributed.run. Also, the doc talks about torchrun, which we are supposed to use; it is probably not ready yet, since the doc does not say how to call it. I probably need to change other options; I only used --nnodes=1 --node_rank=0 --nproc_per_node=2.
  2. In order to log into files, the doc mentions --log_dir, -r, and -t. I tried python -m torch.distributed.launch --log_dir logs -r 3, and also adding -t 3. The launcher creates folders under logs with the stdout and stderr of each process, as described in the doc, but something is still logging to the terminal without being written to the files. The doc says that logging to files applies to WORKERS; I think it is the launcher, which is not considered a worker, that is still logging to the terminal. Warnings, which are important, still go to the terminal and not to a file. This is the terminal output that is not stored in any file:
xxx/lib/python3.7/site-packages/torch/distributed/launch.py:164: 
DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
  "The module torch.distributed.launch is deprecated "

The module torch.distributed.launch is deprecated and going 
to be removed in future.Migrate to torch.distributed.run
*****************************************
Setting OMP_NUM_THREADS environment variable for 
each process to be 1 in default, to avoid your system 
 being overloaded, please further tune the variable for 
optimal performance in your application as needed. 
*****************************************
WARNING:torch.distributed.run:--use_env is 
deprecated and will be removed in future releases.
 Please read local_rank from `os.environ('LOCAL_RANK')` instead.

INFO:torch.distributed.launcher.api:Starting elastic_operator
 with launch configs:
  entrypoint       : multi_g.py
  min_nodes        : 1
  max_nodes        : 1
  nproc_per_node   : 2
  run_id           : none
  rdzv_backend     : static
  rdzv_endpoint    : 127.0.0.1:x
  rdzv_configs     : {'rank': 0, 'timeout': 900}
  max_restarts     : 3
  monitor_interval : 5
  log_dir          : logs
  metrics_cfg      : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log 
directory set to: logs/none_5s_gkwa5
INFO:torch.distributed.elastic.agent.server.api:[default]
 starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:[default] 
Rendezvous'ing worker group
xxx/lib/python3.7/site-packages/torch/distributed/elastic/utils/store.py:53: 
FutureWarning: This is an experimental API and will be changed in future.

  "This is an experimental API and will be changed in future.", 
FutureWarning
INFO:torch.distributed.elastic.agent.server.api:[default] 
Rendezvous complete for workers. Result:
  restart_count=0
  master_addr=127.0.0.1
  master_port=x
  group_rank=0
  group_world_size=1
  local_ranks=[0, 1]
  role_ranks=[0, 1]
  global_ranks=[0, 1]
  role_world_sizes=[2, 2]
  global_world_sizes=[2, 2]

INFO:torch.distributed.elastic.agent.server.api:[default]
 Starting worker group
INFO:torch.distributed.elastic.multiprocessing:
Setting worker0 reply file to: 
logs/none_5s_gkwa5/attempt_0/0/error.json
INFO:torch.distributed.elastic.multiprocessing:
Setting worker1 reply file to: 
logs/none_5s_gkwa5/attempt_0/1/error.json
INFO:torch.distributed.elastic.agent.server.api:[default]
 worker group successfully finished. 
Waiting 300 seconds for other agents to finish.
INFO:torch.distributed.elastic.agent.server.api:
Local worker group finished (SUCCEEDED). 
Waiting 300 seconds for other agents to finish
xxx/lib/python3.7/site-packages/torch/distributed/elastic/
utils/store.py:71: FutureWarning: 
This is an experimental API and will be changed in future.
  "This is an experimental API and will be changed in future.", 
FutureWarning
INFO:torch.distributed.elastic.agent.server.api:
Done waiting for other agents. Elapsed: 
0.000919342041015625 seconds
{"name": "torchelastic.worker.status.SUCCEEDED", 
"source": "WORKER", "timestamp": 0, "metadata":
 {"run_id": "none", "global_rank": 0, "group_rank": 
0, "worker_id": "888", "role": "default",
 "hostname": "xxx", "state": "SUCCEEDED", 
"total_run_time": 10, "rdzv_backend": "static",
 "raw_error": null, "metadata": 
"{\"group_world_size\": 1, \"entry_point\":
 \"python\", \"local_rank\": [0], \"role_rank\": 
[0], \"role_world_size\": [2]}", "agent_restarts": 0}}
{"name": "torchelastic.worker.status.SUCCEEDED",
 "source": "WORKER", "timestamp": 0, "metadata": 
{"run_id": "none", "global_rank": 1, "group_rank": 0,
 "worker_id": "889", "role": "default", "hostname": 
"xxx", "state": "SUCCEEDED", "total_run_time": 10, 
"rdzv_backend": "static", "raw_error": null, "metadata": 
"{\"group_world_size\": 1, \"entry_point\": \"python\", 
\"local_rank\": [1], \"role_rank\": [1], \"role_world_size\": [2]}",
 "agent_restarts": 0}}
{"name": "torchelastic.worker.status.SUCCEEDED", "source": 
"AGENT", "timestamp": 0, "metadata": {"run_id": "none", 
"global_rank": null, "group_rank": 0, "worker_id": null, "role": 
"default", "hostname": "xxx", "state": "SUCCEEDED", 
"total_run_time": 10, "rdzv_backend": "static", "raw_error":
null, "metadata": "{\"group_world_size\": 1, \"entry_point\":
 \"python\"}", "agent_restarts": 0}}

The files such as logs/none_5s_gkwa5/attempt_0/0/error.json are not created!

Again, I guess this is still under development and not ready yet, but it is in master.

Minimal code multigp.py, taken from this:

import argparse

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def spmd_main(local_world_size, local_rank):
    # These are the parameters used to initialize the process group
    np.random.seed(0)

    dist.init_process_group(backend="nccl")
    demo_basic(local_world_size, local_rank)

    # Tear down the process group
    dist.destroy_process_group()


def demo_basic(local_world_size, local_rank):

    # setup devices for this process. For local_world_size = 2, num_gpus = 8,
    # rank 0 uses GPUs [0, 1, 2, 3] and
    # rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))

    model = ToyModel().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--local_world_size", type=int, default=1)
    args = parser.parse_args()
    spmd_main(args.local_world_size, args.local_rank)

Bash command to run it with 2 GPUs:

python -m torch.distributed.launch \
       --log_dir logs \
       -r 3 \
       --nnodes=1 \
       --node_rank=0 \
       --nproc_per_node=2 \
       multigp.py \
       --local_world_size=2
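
For reference, here is a sketch of the corresponding invocation through torch.distributed.run (I am assuming run accepts the same --log_dir/-r/-t logging arguments that launch forwards to it):

# Hedged sketch: same job launched via torch.distributed.run with the logging flags
# from point 2 above (flag names assumed to match what launch forwards to run).
python -m torch.distributed.run \
       --nnodes=1 \
       --node_rank=0 \
       --nproc_per_node=2 \
       --log_dir logs \
       -r 3 \
       -t 3 \
       multigp.py \
       --local_world_size=2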

Thanks

cc @cbalioglu, was wondering if you could look into this?

Hey Soufiane,

  1. When you say “my code freezes” do you mean that your process hangs? Unfortunately I haven’t been able to reproduce this issue. And regarding the torchrun script; you are most likely reading the docs of the master branch. torchrun will be part of v1.10 and is simply an alias to python -m torch.distributed.run.

  2. There are some known issues with the launcher scripts in v1.9. I was able to reproduce your problem with stderr not showing up in the log_dir. In the meantime we have fixed most of the issues and a patch release v1.9.1 will be released very soon. Using the latest nightly build, I verified that your script works as expected with proper log output. If this issue is blocking you, I suggest temporarily using a nightly build and migrating to v1.9.1 in a few days once released.
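
If a nightly build is needed in the meantime, one possible way to install it into a conda environment like the one described further down would be something along these lines (the channel name and CUDA version here are my assumptions; please double-check against the official install instructions):

# Hedged sketch: install a PyTorch nightly build via conda; the pytorch-nightly channel
# and cudatoolkit=11.1 are assumptions based on the environment listed later in the thread.
conda install pytorch torchvision cudatoolkit=11.1 -c pytorch-nightly -c nvidia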

Cheers,
Can

Hi Can,

Yes, by simply replacing torch.distributed.launch, which works fine, with torch.distributed.run, the code hangs. In the example I provided above, I had removed a torch.distributed.barrier that I used to test something; I think it is the reason for the hanging, because when I type Ctrl+C to stop the process, the last line printed in the stack trace was a sleep, as if the process was waiting for something. My guess is that the process was stuck in the barrier. The same code works fine with launch. I can't run anything right now due to a power outage on our servers; please give me some time to provide a full example. As for torchrun, that explains why an error was thrown for not recognizing it, since I am using torch 1.9.0.

This is good news, thanks. It is not urgent for me right now, but reading these logs was essential to find the cause of an issue whose hints were buried in the first logs printed in the terminal. I'll wait for the next release. It could probably be helpful for others to document this aspect for 1.9.0, for example that DDP will turn off multi-threading unless OMP_NUM_THREADS is explicitly set > 1; this was one of the warnings that I missed because the terminal output scrolled fast and was mixed with my own logger.
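
For example (a sketch on my side, assuming the launcher only applies its OMP_NUM_THREADS=1 default when the variable is not already set in the environment; the value 8 is just a placeholder):

# Hedged sketch: export OMP_NUM_THREADS explicitly so workers are not limited to one
# thread by the launcher's default (assumes the default is applied only when unset).
export OMP_NUM_THREADS=8
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 multigp.py --local_world_size=2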

Thanks

So, I ran the code again.

  1. The observed hanging has nothing to do with torch.distributed.barrier; based on the error log, it seems related to NCCL usage. Since you succeeded in running the code above with 1.9.0 using run, it may have something to do with my system.
  2. The code above/below works fine when using torch.distributed.launch but does not work with torch.distributed.run. I am not sure whether this has something to do with my installation. I provide the code and the requirements of my environment below. I use 2 Tesla P100 GPUs for the test. I removed the logging arguments so I could copy all the logs at once.
  3. The *.json log files are never created, with or without an explicit request to log into files, using either launch or run. It is weird because the log often says something like: [INFO] 2021-09-09 17:06:24,434 __init__: Setting worker0 reply file to: /tmp/torchelastic_ji82kwgj/none_v3x8ezxd/attempt_0/0/error.json. But I don't know what the purpose of these log files is; probably they were never needed, so they were never created (see the sketch after this list).
  4. This is the error with run, about NCCL:
Traceback (most recent call last):
File "multig.py", line 60, in <module>
spmd_main(args.local_world_size, args.local_rank)
File "multig.py", line 28, in spmd_main
demo_basic(local_world_size, local_rank)
File "multig.py", line 43, in demo_basic
Traceback (most recent call last):
File "multig.py", line 60, in <module>
ddp_model = DDP(model, device_ids)
File
"xxx/lib/python3.7/site-packages/torch/nn/parallel/distributed.py",
line 496, in __init__
dist._verify_model_across_ranks(self.process_group, parameters)
RuntimeError: NCCL error in:
/opt/conda/conda-bld/pytorch_1623448265233/work/torch/lib/c10d/ProcessGroupNCCL.cpp:911, invalid
usage, NCCL version 2.7.8
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops,
too many collectives at once, mixing streams in a group, etc).
spmd_main(args.local_world_size, args.local_rank)
File "multig.py", line 28, in spmd_main
demo_basic(local_world_size, local_rank)
File "multig.py", line 43, in demo_basic
  5. Please check run.sh to verify whether I used distributed.run correctly in terms of arguments.
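
Regarding point 3 above: my understanding (an assumption based on the torch.distributed.elastic error-handling documentation, not something verified here) is that the error.json reply files only get written when the worker entrypoint is wrapped with the record decorator, roughly like this:

from torch.distributed.elastic.multiprocessing.errors import record


# Hedged sketch: @record is what makes a crashing worker dump its exception into the
# attempt_*/<rank>/error.json reply file announced in the logs (assumption based on
# the torch.distributed.elastic error-handling docs).
@record
def main():
    # ... worker code, e.g. spmd_main(...) from the script below ...
    pass


if __name__ == "__main__":
    main()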

Code multig.py:

import argparse

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def spmd_main(local_world_size, local_rank):
    # These are the parameters used to initialize the process group
    np.random.seed(0)

    dist.init_process_group(backend="nccl")
    demo_basic(local_world_size, local_rank)

    # Tear down the process group
    dist.destroy_process_group()


def demo_basic(local_world_size, local_rank):

    # setup devices for this process. For local_world_size = 2, num_gpus = 8,
    # rank 0 uses GPUs [0, 1, 2, 3] and
    # rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))

    model = ToyModel().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--local_world_size", type=int, default=1)
    args = parser.parse_args()
    spmd_main(args.local_world_size, args.local_rank)

Bash script run.sh:

#!/usr/bin/env bash

# activate conda env here.

# ==============================================================================
cudaid=$1
export CUDA_VISIBLE_DEVICES=$cudaid


python -m torch.distributed.launch \
       --nnodes=1 \
       --node_rank=0 \
       --nproc_per_node=2 \
       multi_g2.py \
       --local_world_size=2

Output with torch.distributed.launch when using ./run.sh 0,1 (the job finished properly):

xxx/lib/python3.7/site-packages/torch/distributed/launch.py:164:
 DeprecationWarning: 
The 'warn' method is deprecated, use 'warning' instead
  "The module torch.distributed.launch is deprecated "
The module torch.distributed.launch is deprecated and 
going to be removed in future.Migrate to torch.distributed.run
*****************************************
Setting OMP_NUM_THREADS environment variable for each process
 to be 1 in default, to avoid your system being overloaded, please
 further tune the variable for optimal performance in your application 
as needed. 
*****************************************
WARNING:torch.distributed.run:--use_env is deprecated and will be
 removed in future releases.
 Please read local_rank from `os.environ('LOCAL_RANK')` instead.
INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : multi_g2.py
  min_nodes        : 1
  max_nodes        : 1
  nproc_per_node   : 2
  run_id           : none
  rdzv_backend     : static
  rdzv_endpoint    : 127.0.0.1:x
  rdzv_configs     : {'rank': 0, 'timeout': 900}
  max_restarts     : 3
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log 
directory set to: /tmp/torchelastic_db1cckbb/none_y8sql577
INFO:torch.distributed.elastic.agent.server.api:
[default] starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:
[default] Rendezvous'ing worker group
xxx/lib/python3.7/site-packages/torch/distributed/elastic/utils/store.py:53:
 FutureWarning: This is an experimental API and will be changed in future.
  "This is an experimental API and will be changed in future.", FutureWarning
INFO:torch.distributed.elastic.agent.server.api:
[default] Rendezvous complete for workers. Result:
  restart_count=0
  master_addr=127.0.0.1
  master_port=xxx
  group_rank=0
  group_world_size=1
  local_ranks=[0, 1]
  role_ranks=[0, 1]
  global_ranks=[0, 1]
  role_world_sizes=[2, 2]
  global_world_sizes=[2, 2]

INFO:torch.distributed.elastic.agent.server.api
:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:
Setting worker0 reply file to: 
/tmp/torchelastic_db1cckbb/none_y8sql577/attempt_0/0/error.json
INFO:torch.distributed.elastic.multiprocessing:
Setting worker1 reply file to:
 /tmp/torchelastic_db1cckbb/none_y8sql577/attempt_0/1/error.json
INFO:torch.distributed.elastic.agent.server.api:
[default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
INFO:torch.distributed.elastic.agent.server.api:
Local worker group finished (SUCCEEDED). Waiting 300 seconds for other agents to finish
xxx/lib/python3.7/site-packages/torch/distributed/elastic/utils/store.py:71: 
FutureWarning: This is an experimental API and will be changed in future.
  "This is an experimental API and will be changed in future.", FutureWarning
INFO:torch.distributed.elastic.agent.server.api:Done waiting for other agents. Elapsed: 0.0006513595581054688 seconds
{"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "none", "global_rank": 0, "group_rank": 0, "worker_id": "21051", "role": "default", "hostname": "xxx", "state": "SUCCEEDED", "total_run_time": 10, "rdzv_backend": "static", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"python\", \"local_rank\": [0], \"role_rank\": [0], \"role_world_size\": [2]}", "agent_restarts": 0}}
{"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "none", "global_rank": 1, "group_rank": 0, "worker_id": "21053", "role": "default", "hostname": "xxx", "state": "SUCCEEDED", "total_run_time": 10, "rdzv_backend": "static", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"python\", \"local_rank\": [1], \"role_rank\": [1], \"role_world_size\": [2]}", "agent_restarts": 0}}
{"name": "torchelastic.worker.status.SUCCEEDED", "source": "AGENT", "timestamp": 0, "metadata": {"run_id": "none", "global_rank": null, "group_rank": 0, "worker_id": null, "role": "default", "hostname": "xxx", "state": "SUCCEEDED", "total_run_time": 10, "rdzv_backend": "static", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"python\"}", "agent_restarts": 0}}

Now, the output with torch.distributed.run when using ./run.sh 0,1 (the job hangs):

$ ./run.sh 0,1

xxx/lib/python3.7/site-packages/torch/distributed/launch.py
[INFO] 2021-09-09 17:06:24,426 run: Running torch.distributed.run with args:
['xxx/lib/python3.7/site-packages/torch/distributed/run.py',
'--nnodes=1', '--node_rank=0', '--nproc_per_node=2', 'multig.py', '--local_world_size=2']
[INFO] 2021-09-09 17:06:24,427 run: Using nproc_per_node=2.
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your
system being overloaded, please further tune the variable for optimal performance in your
application as needed. 
*****************************************
[INFO] 2021-09-09 17:06:24,427 api: Starting elastic_operator with launch configs:
entrypoint       : multig.py
min_nodes        : 1
max_nodes        : 1
nproc_per_node   : 2
run_id           : none
rdzv_backend     : static
rdzv_endpoint    : 127.0.0.1:xxx
rdzv_configs     : {'rank': 0, 'timeout': 900}
max_restarts     : 3
monitor_interval : 5
log_dir          : None
metrics_cfg      : {}

[INFO] 2021-09-09 17:06:24,429 local_elastic_agent: log directory set to:
/tmp/torchelastic_ji82kwgj/none_v3x8ezxd
[INFO] 2021-09-09 17:06:24,429 api: [default] starting workers for entrypoint: python
[INFO] 2021-09-09 17:06:24,429 api: [default] Rendezvous'ing worker group
[INFO] 2021-09-09 17:06:24,429 static_tcp_rendezvous: Creating TCPStore as the c10d::Store
implementation

xxx/lib/python3.7/site-packages/torch/distributed/elastic/utils/store.py:53:
FutureWarning: This is an experimental API and will be changed in future.
"This is an experimental API and will be changed in future.", FutureWarning
[INFO] 2021-09-09 17:06:24,433 api: [default] Rendezvous complete for workers. Result:
restart_count=0
master_addr=127.0.0.1
master_port=same_as_above
group_rank=0
group_world_size=1
local_ranks=[0, 1]
role_ranks=[0, 1]
global_ranks=[0, 1]
role_world_sizes=[2, 2]
global_world_sizes=[2, 2]

[INFO] 2021-09-09 17:06:24,433 api: [default] Starting worker group
[INFO] 2021-09-09 17:06:24,434 __init__: Setting worker0 reply file to:
/tmp/torchelastic_ji82kwgj/none_v3x8ezxd/attempt_0/0/error.json
[INFO] 2021-09-09 17:06:24,434 __init__: Setting worker1 reply file to:
/tmp/torchelastic_ji82kwgj/none_v3x8ezxd/attempt_0/1/error.json
Traceback (most recent call last):
File "multig.py", line 60, in <module>
spmd_main(args.local_world_size, args.local_rank)
File "multig.py", line 28, in spmd_main
demo_basic(local_world_size, local_rank)
File "multig.py", line 43, in demo_basic
Traceback (most recent call last):
File "multig.py", line 60, in <module>
ddp_model = DDP(model, device_ids)
File
"xxx/lib/python3.7/site-packages/torch/nn/parallel/distributed.py",
line 496, in __init__
dist._verify_model_across_ranks(self.process_group, parameters)
RuntimeError: NCCL error in:
/opt/conda/conda-bld/pytorch_1623448265233/work/torch/lib/c10d/ProcessGroupNCCL.cpp:911, invalid
usage, NCCL version 2.7.8
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops,
too many collectives at once, mixing streams in a group, etc).
spmd_main(args.local_world_size, args.local_rank)
File "multig.py", line 28, in spmd_main
demo_basic(local_world_size, local_rank)
File "multig.py", line 43, in demo_basic
ddp_model = DDP(model, device_ids)
File
"xxx/lib/python3.7/site-packages/torch/nn/parallel/distributed.py",
line 496, in __init__
dist._verify_model_across_ranks(self.process_group, parameters)
RuntimeError: NCCL error in:
/opt/conda/conda-bld/pytorch_1623448265233/work/torch/lib/c10d/ProcessGroupNCCL.cpp:911, invalid
usage, NCCL version 2.7.8
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops,
too many collectives at once, mixing streams in a group, etc).
[ERROR] 2021-09-09 17:06:34,499 api: failed (exitcode: 1) local_rank: 0 (pid: 26925) of binary:
xxx/bin/python
[ERROR] 2021-09-09 17:06:34,499 local_elastic_agent: [default] Worker group failed
[INFO] 2021-09-09 17:06:34,499 api: [default] Worker group FAILED. 3/3 attempts left; will restart
worker group
[INFO] 2021-09-09 17:06:34,499 api: [default] Stopping worker group
[INFO] 2021-09-09 17:06:34,500 api: [default] Rendezvous'ing worker group
[INFO] 2021-09-09 17:06:34,500 static_tcp_rendezvous: Creating TCPStore as the c10d::Store
implementation
[INFO] 2021-09-09 17:06:34,501 api: [default] Rendezvous complete for workers. Result:
restart_count=1
master_addr=127.0.0.1
master_port=xx
group_rank=0
group_world_size=1
local_ranks=[0, 1]
role_ranks=[0, 1]
global_ranks=[0, 1]
role_world_sizes=[2, 2]
global_world_sizes=[2, 2]

[INFO] 2021-09-09 17:06:34,501 api: [default] Starting worker group
[INFO] 2021-09-09 17:06:34,502 __init__: Setting worker0 reply file to:
/tmp/torchelastic_ji82kwgj/none_v3x8ezxd/attempt_1/0/error.json
[INFO] 2021-09-09 17:06:34,503 __init__: Setting worker1 reply file to:
/tmp/torchelastic_ji82kwgj/none_v3x8ezxd/attempt_1/1/error.json

<HANGS INDEFINITELY>

After typing Ctrl+C:

^CTraceback (most recent call last):
File "multig.py", line 60, in <module>
spmd_main(args.local_world_size, args.local_rank)
File "multig.py", line 27, in spmd_main
dist.init_process_group(backend="nccl")
File
"xxx/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py",
line 547, in init_process_group
_store_based_barrier(rank, store, timeout)
File
"xxx/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py",
line 207, in _store_based_barrier
time.sleep(0.01)
KeyboardInterrupt
Traceback (most recent call last):
File "xxx/lib/python3.7/runpy.py", line
193, in _run_module_as_main
"__main__", mod_spec)
File "xxx/lib/python3.7/runpy.py", line
85, in _run_code
exec(code, run_globals)
File
"xxx/lib/python3.7/site-packages/torch/distributed/run.py",
line 637, in <module>
main()
File
"xxx/lib/python3.7/site-packages/torch/distributed/run.py",
line 629, in main
run(args)
File
"xxx/lib/python3.7/site-packages/torch/distributed/run.py",
line 624, in run
)(*cmd_args)
File
"xxx/lib/python3.7/site-packages/torch/distributed/launcher/api.py",
line 116, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File
"xxx/lib/python3.7/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py",
line 348, in wrapper
return f(*args, **kwargs)
File
"xxx/lib/python3.7/site-packages/torch/distributed/launcher/api.py",
line 238, in launch_agent
result = agent.run()
File
"xxx/lib/python3.7/site-packages/torch/distributed/elastic/metrics/api.py",
line 125, in wrapper
result = f(*args, **kwargs)
File
"xxx/lib/python3.7/site-packages/torch/distributed/elastic/agent/server/api.py",
line 700, in run
result = self._invoke_run(role)
File
"xxx/lib/python3.7/site-packages/torch/distributed/elastic/agent/server/api.py",
line 828, in _invoke_run
time.sleep(monitor_interval)
KeyboardInterrupt

Environment:

  1. Info collected using https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py:
$ python collect_env.py 
Collecting environment information...
PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.10

Python version: 3.7.9 (default, Aug 31 2020, 12:42:55)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-4.15.0-122-generic-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration: 
GPU 0: Tesla P100-PCIE-16GB
GPU 1: Tesla P100-PCIE-16GB

Nvidia driver version: 455.32.00
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.4.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] efficientnet-pytorch==0.7.0
[pip3] numpy==1.20.1
[pip3] torch==1.9.0
[pip3] torchvision==0.10.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.1.74              h6bb024c_0    nvidia
[conda] efficientnet-pytorch      0.7.0                    pypi_0    pypi
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.3.0           h06a4308_520  
[conda] numpy                     1.20.1                   pypi_0    pypi
[conda] pytorch                   1.9.0           py3.7_cuda11.1_cudnn8.0.5_0    pytorch
[conda] torchvision               0.10.0               py37_cu111    pytorch

Conda virtual environment:

$ pip freeze
appdirs==1.4.4
attrs==20.3.0
backcall==0.1.0
bleach==2.1.3
certifi==2021.5.30
chardet==3.0.4
compress-pickle==1.1.0
cycler==0.10.0
Cython==0.29.2
decorator==4.3.2
efficientnet-pytorch==0.7.0
entrypoints==0.2.3
future==0.16.0
html5lib==1.0.1
idna==2.10
imageio==2.4.1
importlib-metadata==3.5.0
iniconfig==1.1.1
ipykernel==4.8.2
ipython==6.5.0
ipython-genutils==0.2.0
ipywidgets==7.4.2
jedi==0.12.1
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.4
jupyter-console==5.2.0
jupyter-core==4.4.0
kiwisolver==1.0.1
MarkupSafe==1.1.1
matplotlib==3.0.2
mistune==0.8.4
mock==4.0.3
more-itertools==8.8.0
munch==2.5.0
nbconvert==5.3.1
nbformat==4.4.0
networkx==2.5
notebook==5.7.4
numpy==1.20.1
olefile==0.46
opencv-python==4.1.2.30
packaging==20.9
pandocfilters==1.4.2
parso==0.3.1
pexpect==4.6.0
pickleshare==0.7.5
Pillow @ file:///tmp/build/80754af9/pillow_1625655818400/work   (8.3.1)
pluggy==0.13.1
pretrainedmodels==0.7.4
prometheus-client==0.3.1
prompt-toolkit==1.0.15
protobuf==3.7.1
ptyprocess==0.6.0
py==1.10.0
pygifsicle==1.0.1
Pygments==2.3.1
pyparsing==2.3.1
pytest==6.2.2
python-dateutil==2.8.0
PyWavelets==1.1.1
PyYAML==3.13
pyzmq==17.1.2
qtconsole==4.3.1
requests==2.24.0
scikit-image==0.17.2
scikit-learn==0.20.2
scipy==1.2.1
Send2Trash==1.5.0
simplegeneric==0.8.1
six==1.12.0
terminado==0.8.1
testpath==0.3.1
texttable==1.6.2
tifffile==2020.10.1
timm==0.4.12
toml==0.10.2
torch==1.9.0
torchvision==0.10.0
tornado==5.1.1
tqdm==4.31.1
traitlets==4.3.2
typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1624965014186/work   (3.10.0.0)
urllib3==1.25.10
wcwidth==0.1.7
webencodings==0.5.1
widgetsnbextension==3.4.2
zipp==3.4.0

I removed some packages, as they require compilation.

Creation of the virtual environment and installation of PyTorch:

conda create -n env_test_issue_ddp python=3.7
conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=11.1 -c pytorch -c nvidia

Let me know if you need more info.

Thanks

ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops,
too many collectives at once, mixing streams in a group, etc).

This probably indicates a case where we end up using NCCL incorrectly. One guess I have is that the environment variables for rank and world_size are not set up correctly when we use torch.distributed.run.
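
If it helps with checking that guess, here is a minimal sketch of reading those variables inside the worker (this assumes torch.distributed.run exports LOCAL_RANK, RANK, and WORLD_SIZE for each process, as the --use_env deprecation warning earlier in the thread suggests):

import os

# Hedged sketch: print the rank-related variables that torch.distributed.run is expected
# to export per worker (per the --use_env deprecation warning), to verify they are set
# as expected before dist.init_process_group(backend="nccl") picks them up via env://.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
print(f"LOCAL_RANK={local_rank} RANK={rank} WORLD_SIZE={world_size}")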

This is the config I used; instead of launch, I used run.
Any idea how to set up these variables in this case? I can try them.
My environment has 2 GPUs located on the same machine.
Thanks