nn.DataParallel gets stuck

Hello.

I’m trying to train a model on multiple GPUs using nn.DataParallel, and the program gets stuck, in the sense that I can’t even Ctrl+C to stop it. My system has 3x A100 GPUs. The same code works fine with nn.DataParallel on a multi-GPU system with V100 GPUs. How can I debug what’s going wrong?

I installed pytorch and cudatoolkit using anaconda; both are at their latest versions as of writing this post. The NVIDIA driver version is 465 and the CUDA version is 11.0.
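
For context, the model is wrapped in nn.DataParallel in the standard way; a minimal sketch of the pattern (a stand-in model, not my actual code) looks like this:

import torch
import torch.nn as nn

# stand-in model for illustration only
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))
model = nn.DataParallel(model.cuda())   # replicate across all visible GPUs

x = torch.randn(6, 10, device='cuda')   # the batch is split across the GPUs
out = model(x)                          # outputs are gathered back on GPU 0
out.sum().backward()

My real training loop follows this pattern; it runs fine on the V100 machine but hangs during training on the A100 machine.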

Output of $ nvidia-smi topo -m

        GPU0    GPU1    GPU2    CPU Affinity    NUMA Affinity
GPU0     X      SYS     SYS     0-15,32-47      0
GPU1    SYS      X      NODE    16-31,48-63     1
GPU2    SYS     NODE     X      16-31,48-63     1

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Output of $ nvidia-smi

$ nvidia-smi
Wed Jun 30 05:51:19 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-PCI...  On   | 00000000:21:00.0 Off |                    0 |
| N/A   21C    P0    33W / 250W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PCI...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   22C    P0    35W / 250W |  11615MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-PCI...  On   | 00000000:E2:00.0 Off |                    0 |
| N/A   21C    P0    32W / 250W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               

Output of $ python -c "import torch;print(torch.__version__);print(torch.cuda.is_available())"

1.8.1+cu111

True

Thank you


I’ve also run into problems like this. Have you tried using nn.DistributedDataParallel instead? It appears to be recommended over nn.DataParallel even for multi-GPU training on a single node:

DataParallel — PyTorch 1.9.0 documentation

CUDA semantics — PyTorch 1.9.0 documentation


Hi,

Thanks for the reply. I wasn’t looking at the DistributedDataParallel option because it seems to carry more overhead than my use case justifies. I did try it now nonetheless, and here is the error I get:

RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:825, invalid usage, NCCL version 2.7.8
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).
    dist._broadcast_coalesced(
RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:825, invalid usage, NCCL version 2.7.8
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).
Killing subprocess 1658838
Killing subprocess 1658839
Traceback (most recent call last):
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/launch.py", line 340, in <module>
    main()
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/launch.py", line 326, in main
    sigkill_handler(signal.SIGTERM, None)  # not coming back
  File "/home/user/.local/lib/python3.8/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
    raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)

I’m still not convinced as to why DataParallel shouldn’t work, especially since it runs fine on another machine with the same software.

Interesting. I never really had nn.DataParallel working for me either; I often found that it just hangs during training. It could possibly be my machine, but I’m not sure. I’ll probably give it another go soon along with nn.DistributedDataParallel.

The legendary @ptrblck may know something though!

Hello @ptrblck, @fmassa. I would be really grateful if you could weigh in with your thoughts. Thank you.

I don’t know why nn.DataParallel is hanging instead of working. You could attach to the hanging process via gdb and check the backtrace to see which operation is stuck.
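
Something along these lines should work (replace <PID> with the process id of the hanging Python process; you might need sudo depending on the ptrace settings):

$ gdb -p <PID>
(gdb) thread apply all bt
(gdb) detach
(gdb) quit

The per-thread backtraces usually show whether the process is stuck inside a CUDA/NCCL call or somewhere on the Python side.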

As @Epoching said, nn.DataParallel is generally slower than nn.DistributedDataParallel, so I never use it.
For the DDP error: make sure that simple examples work first, and then check whether you are using DDP incorrectly.
E.g. use this script:

# script.py
# Minimal DDP sanity check: two processes, one tiny model, two iterations.
import argparse

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2, bias=False)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        print('fc1.weight {}'.format(self.fc1.weight))
        x = self.fc1(x)
        x = self.drop(x)
        print('x {}'.format(x))
        return x


def main():
    parser = argparse.ArgumentParser(description='fdsa')
    parser.add_argument("--local_rank", default=0, type=int)
    args = parser.parse_args()

    # one process per GPU; the launcher passes the local rank
    args.gpu = args.local_rank
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

    model = MyModel().to(args.gpu)
    model = DistributedDataParallel(
        model,
        device_ids=[args.gpu],
        output_device=args.local_rank,
    )

    # two dummy iterations: forward, print, backward
    for i in range(2):
        model.zero_grad()
        x = torch.randn(1, 2, device=args.gpu)
        out = model(x)
        print('iter {}, out {}'.format(i, out))
        out.mean().backward()


if __name__ == "__main__":
    main()

and execute it via:

python -m torch.distributed.launch --nproc_per_node=2 script.py 

Hello @ptrblck , sorry for my late answer.

It seems to be the same issue as with nn.DataParallel. The process gets stuck and the volatile GPU utilization sits at 100%; I’ve attached a screenshot of nvidia-smi.

Are you already using the latest PyTorch release, and if not, could you update to it?
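
You can double-check which versions are currently picked up with something like:

python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.nccl.version())"

which prints the PyTorch version, the CUDA version it was built with, and the bundled NCCL version.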

Hello @ptrblck, I updated torch to the nightly version using conda. The code still gets stuck at the first iteration. Just as before, the nvidia-smi output shows 100% GPU utilization but the process doesn’t move ahead. Attaching the output of the script below in case it helps.

The module torch.distributed.launch is deprecated and going to be removed in future.Migrate to torch.distributed.run
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
WARNING:torch.distributed.run:--use_env is deprecated and will be removed in future releases.
 Please read local_rank from `os.environ('LOCAL_RANK')` instead.
INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : gpu_test_script.py
  min_nodes        : 1
  max_nodes        : 1
  nproc_per_node   : 2
  run_id           : none
  rdzv_backend     : static
  rdzv_endpoint    : 127.0.0.1:29500
  rdzv_configs     : {'rank': 0, 'timeout': 900}
  max_restarts     : 3
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log directory set to: /tmp/torchelastic_t4bhfk0o/none_49jop45c
INFO:torch.distributed.elastic.agent.server.api:[default] starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous'ing worker group
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous complete for workers. Result:
  restart_count=0
  master_addr=127.0.0.1
  master_port=29500
  group_rank=0
  group_world_size=1
  local_ranks=[0, 1]
  role_ranks=[0, 1]
  global_ranks=[0, 1]
  role_world_sizes=[2, 2]
  global_world_sizes=[2, 2]

INFO:torch.distributed.elastic.agent.server.api:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:Setting worker0 reply file to: /tmp/torchelastic_t4bhfk0o/none_49jop45c/attempt_0/0/error.json
INFO:torch.distributed.elastic.multiprocessing:Setting worker1 reply file to: /tmp/torchelastic_t4bhfk0o/none_49jop45c/attempt_0/1/error.json
fc1.weight Parameter containing:
tensor([[-0.6771, -0.5446],
        [-0.2001, -0.0258]], device='cuda:0', requires_grad=True)
x tensor([[0., 0.]], device='cuda:0', grad_fn=<FusedDropoutBackward>)
iter 0, out tensor([[0., 0.]], device='cuda:0', grad_fn=<_DDPSinkBackward>)

Also, the torch version:

Python 3.8.10 (default, Jun  4 2021, 15:09:15)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.23.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

In [2]: torch.__version__
Out[2]: '1.10.0.dev20210706'

Thanks

Hello @ptrblck, a brief reminder on this thread. It’s unfortunate that I cannot use all the GPU resources I have access to. I would be very grateful if you could help me figure out what the issue is with my setup. Thank you once again.

As I cannot reproduce the issue, my best recommendation would be to update the driver.
In fact, I’m also debugging other hangs on 465.19.01, so your hang could indicate a real issue in this particular driver release. Let me know if updating (or downgrading) works.

Hello @ptrblck ,

I updated the driver to 470. The issue persists; attached below is the output of watch nvidia-smi:

Every 2.0s: nvidia-smi                                                                                                                                                ampere.lix.polytechnique.fr: Thu Jul 22 09:21:53 2021

Thu Jul 22 09:21:53 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-PCI...  On   | 00000000:21:00.0 Off |                    0 |
| N/A   25C    P0    59W / 250W |   1768MiB / 40536MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PCI...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   25C    P0    58W / 250W |   1502MiB / 40536MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-PCI...  On   | 00000000:E2:00.0 Off |                    0 |
| N/A   20C    P0    32W / 250W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      5629      C   .../envs/torch1p9/bin/python     1765MiB |
|    1   N/A  N/A      5630      C   .../envs/torch1p9/bin/python     1499MiB |
+-----------------------------------------------------------------------------+

I ran the same script you suggested above, using the latest release of PyTorch. It gets stuck once again, but this time at least Ctrl+C can interrupt the script and I don’t have to kill the process to stop it.

Hello @ptrblck, adding more details.

The command that gets stuck:

$ CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 gpu_test_script.py
/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/launch.py:177: FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torch.distributed.run.
Note that --use_env is set by default in torch.distributed.run.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See
https://pytorch.org/docs/stable/distributed.html#launch-utility for
further instructions

  warnings.warn(
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
fc1.weight Parameter containing:
tensor([[ 0.0018,  0.6833],
        [-0.4977, -0.3072]], device='cuda:0', requires_grad=True)
x tensor([[0.8429, 0.0000]], device='cuda:0', grad_fn=<FusedDropoutBackward>)
iter 0, out tensor([[0.8429, 0.0000]], device='cuda:0', grad_fn=<_DDPSinkBackward>)

Upon KeyboardInterrupt:

^CTraceback (most recent call last):
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/launch.py", line 191, in <module>
    main()
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/launch.py", line 187, in main
    launch(args)
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/launch.py", line 173, in launch
    run(args)
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/run.py", line 688, in run
    elastic_launch(
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 252, in launch_agent
    result = agent.run()
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper
    result = f(*args, **kwargs)
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 700, in run
    result = self._invoke_run(role)
  File "/home/user/.conda/envs/torch1p9/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 828, in _invoke_run
    time.sleep(monitor_interval)
KeyboardInterrupt

Version info:

$ python -c "import torch;print(torch.__version__)"
1.10.0.dev20210721
/usr/local/cuda-11.4/bin/nvcc
(torch1p9) user@ampere:~/Project$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Wed_Jun__2_19:15:15_PDT_2021
Cuda compilation tools, release 11.4, V11.4.48
Build cuda_11.4.r11.4/compiler.30033411_0

Ping @ptrblck, any ideas on what could be done? The current setup is very time-consuming given that we can train on only one GPU. If you could suggest a version of the NVIDIA driver and CUDA where nn.DataParallel or nn.DistributedDataParallel is known to work on A100 GPUs, that would be very much appreciated. Thank you.

I’m unfortunately unable to reproduce the issue on our A100 servers. In case you are running on bare metal, could you try to run some workloads in a docker container and check if it’s still hanging?
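
For example, something along these lines using one of the NGC PyTorch containers should work (the tag is just an example, any recent one is fine):

$ docker run --gpus all --rm -it --ipc=host -v $PWD:/workspace \
    nvcr.io/nvidia/pytorch:21.06-py3 \
    python -m torch.distributed.launch --nproc_per_node=2 gpu_test_script.py

These containers ship with a matching CUDA/NCCL stack, so a hang inside them would point more towards the driver or the hardware/BIOS setup.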

Hello @ptrblck thanks for your answer,

I tried using the Dockerfile here. I also checked whether the Ubuntu version was the reason; it isn’t, as the code gets stuck on 18.04 as well. Can you please suggest what else I could try to fix this? Attached below are the terminal outputs:

$ docker run --rm -it --init --gpus=2 --ipc=host --volume="$PWD:/app" anibali/pytorch:1.8.1-cuda11.1 python3 -m torch.distributed.launch --nproc_per_node=2 gpu_test_script.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
fc1.weight Parameter containing:
tensor([[ 0.2373, -0.6326],
        [-0.5230,  0.0513]], device='cuda:0', requires_grad=True)
x tensor([[-0.4446,  0.0000]], device='cuda:0', grad_fn=<FusedDropoutBackward>)
iter 0, out tensor([[-0.4446,  0.0000]], device='cuda:0', grad_fn=<FusedDropoutBackward>)
^CKilling subprocess 9
Killing subprocess 10
Main process received SIGINT, exiting

$ docker run --rm -it --init --gpus=2 --ipc=host --volume="$PWD:/app" cuda11_ubuntu1804 python3 -m torch.distributed.launch --nproc_per_node=2 gpu_test_script.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
fc1.weight Parameter containing:
tensor([[-0.1156, -0.3040],
        [-0.5875,  0.4592]], device='cuda:0', requires_grad=True)
x tensor([[0., 0.]], device='cuda:0', grad_fn=<FusedDropoutBackward>)
iter 0, out tensor([[0., 0.]], device='cuda:0', grad_fn=<FusedDropoutBackward>)

Thank you again for your time, patience, and help @ptrblck, I much appreciate it.

I have exactly the same issue as you. Please let me know if there are any solutions.

Hi @Passlab_Yaying, no, I still cannot use multi-GPU training. I’m adjusting for the time being by using one GPU per task. I hope @ptrblck can share more insights/debugging steps.

Thanks for the update. As another test you could try to launch the container via --gpus all. If that doesn’t help, I would recommend contacting your system admin, as I haven’t experienced these issues on our A100 nodes.

Hi @Sentient07, @ptrblck, I got this issue fixed by changing the CPU IOMMU feature from Auto to Disabled. Thank you for your help. Let me know if it works for you.
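
(In case it helps others check their setup first: on Linux you can usually tell whether the IOMMU is currently enabled with something like the commands below; the exact messages differ between Intel and AMD systems. If the second command lists numbered group directories, the IOMMU is active.)

$ dmesg | grep -i iommu
$ ls /sys/kernel/iommu_groups/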
