Specify device_ids in barrier() to force use of a particular device

I have a problem running the spawn function from torch.multiprocessing on Slurm with multiple GPUs.

Instructions To Reproduce the Issue:

  1. Full runnable code:
import torch, os

def test_nccl_ops():
    num_gpu = 2
    print("NCCL init before spawn")
    import torch.multiprocessing as mp
    dist_url = "file:///tmp/nccl_tmp_file"
    mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False)
    print("NCCL init succeeded.")


def _test_nccl_worker(rank, num_gpu, dist_url):
    import torch.distributed as dist

    dist.init_process_group(backend="NCCL", init_method=dist_url, rank=rank, world_size=num_gpu)
    dist.barrier()
    print("Worker after barrier")

if __name__ == "__main__":
    test_nccl_ops()

We use the following Slurm script to run the experiment on 2 GPUs:

#!/bin/bash -l

#SBATCH --account=Account
#SBATCH --partition=gpu # gpu partition
#SBATCH --nodes=1 # 1 node, 4 GPUs per node
#SBATCH --time=24:00:00 
#SBATCH --job-name=detectron2_demo4 # job name



module load Python/3.9.5-GCCcore-10.3.0
module load CUDA/11.1.1-GCC-10.2.0

cd /experiment_path

export NCCL_DEBUG=INFO

srun python main.py --num-gpus 2

When I ran this script I got the following output (cat slurm-xxx.out) and no separate error file:

The following have been reloaded with a version change:
  1) GCCcore/10.3.0 => GCCcore/10.2.0
  2) binutils/2.36.1-GCCcore-10.3.0 => binutils/2.35-GCCcore-10.2.0
  3) zlib/1.2.11-GCCcore-10.3.0 => zlib/1.2.11-GCCcore-10.2.0

NCCL init before spawn
[W ProcessGroupNCCL.cpp:1569] Rank 0 using best-guess GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.
[W ProcessGroupNCCL.cpp:1569] Rank 1 using best-guess GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.
gpu04:9770:9770 [0] NCCL INFO Bootstrap : Using [0]bond0:10.10.1.4<0>
gpu04:9770:9770 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
gpu04:9770:9770 [0] NCCL INFO NET/IB : No device found.
gpu04:9770:9770 [0] NCCL INFO NET/Socket : Using [0]bond0:10.10.1.4<0>
gpu04:9770:9770 [0] NCCL INFO Using network Socket
NCCL version 2.7.8+cuda10.2
gpu04:9771:9771 [1] NCCL INFO Bootstrap : Using [0]bond0:10.10.1.4<0>
gpu04:9771:9771 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
gpu04:9771:9771 [1] NCCL INFO NET/IB : No device found.
gpu04:9771:9771 [1] NCCL INFO NET/Socket : Using [0]bond0:10.10.1.4<0>
gpu04:9771:9771 [1] NCCL INFO Using network Socket
gpu04:9771:9862 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
gpu04:9771:9862 [1] NCCL INFO Trees [0] -1/-1/-1->1->0|0->1->-1/-1/-1 [1] -1/-1/-1->1->0|0->1->-1/-1/-1
gpu04:9771:9862 [1] NCCL INFO Setting affinity for GPU 1 to 3fff
gpu04:9770:9861 [0] NCCL INFO Channel 00/02 :    0   1
gpu04:9770:9861 [0] NCCL INFO Channel 01/02 :    0   1
gpu04:9770:9861 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
gpu04:9770:9861 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1|-1->0->1/-1/-1 [1] 1/-1/-1->0->-1|-1->0->1/-1/-1
gpu04:9770:9861 [0] NCCL INFO Setting affinity for GPU 0 to 3fff
gpu04:9771:9862 [1] NCCL INFO Channel 00 : 1[6000] -> 0[5000] via P2P/IPC
gpu04:9770:9861 [0] NCCL INFO Channel 00 : 0[5000] -> 1[6000] via P2P/IPC
gpu04:9771:9862 [1] NCCL INFO Channel 01 : 1[6000] -> 0[5000] via P2P/IPC
gpu04:9770:9861 [0] NCCL INFO Channel 01 : 0[5000] -> 1[6000] via P2P/IPC
gpu04:9771:9862 [1] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
gpu04:9771:9862 [1] NCCL INFO comm 0x7f057c000e00 rank 1 nranks 2 cudaDev 1 busId 6000 - Init COMPLETE
gpu04:9770:9861 [0] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
gpu04:9770:9861 [0] NCCL INFO comm 0x7f5210000e00 rank 0 nranks 2 cudaDev 0 busId 5000 - Init COMPLETE
gpu04:9770:9770 [0] NCCL INFO Launch mode Parallel


Expected behavior:

Training should run on 2 GPUs and print more output than just “NCCL init before spawn” and the NCCL debug info.

Environment:

Output of the environment collection command:

No CUDA runtime is found, using CUDA_HOME='/usr/local/software/CUDAcore/11.1.1'
---------------------  --------------------------------------------------------------------------------
sys.platform           linux
Python                 3.9.5 (default, Jul  9 2021, 09:35:24) [GCC 10.3.0]
numpy                  1.21.1
detectron2             0.5 @/home/users/aimhigh/detectron2/detectron2
Compiler               GCC 10.2
CUDA compiler          CUDA 11.1
DETECTRON2_ENV_MODULE  <not set>
PyTorch                1.9.0+cu102 @/home/users/aimhigh/.local/lib/python3.9/site-packages/torch
PyTorch debug build    False
GPU available          No: torch.cuda.is_available() == False
Pillow                 8.3.1
torchvision            0.10.0+cu102 @/home/users/aimhigh/.local/lib/python3.9/site-packages/torchvision
fvcore                 0.1.5.post20210727
iopath                 0.1.9
cv2                    4.5.3
---------------------  --------------------------------------------------------------------------------
PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.1.2 
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2

Additional note: at first I assumed this was a detectron2 problem, but it is not. You can find my previous discussion with the detectron2 developers here: link. Maybe dist_url is somehow problematic, or perhaps some additional Slurm configuration is needed.
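
As a sanity check, a TCP store could replace the shared file; this is just a sketch, and the address and port below are placeholders I have not verified on the cluster:

dist_url = "tcp://127.0.0.1:12345"  # placeholder endpoint, instead of "file:///tmp/nccl_tmp_file"
mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False)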

Additional example:
Here is a second, minimal script that shows the same behavior:

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

# On Windows platform, the torch.distributed package only
# supports Gloo backend, FileStore and TcpStore.
# For FileStore, set init_method parameter in init_process_group
# to a local file. Example as follow:
# init_method="file:///f:/libtmp/some_file"
# dist.init_process_group(
#    "gloo",
#    rank=rank,
#    init_method=init_method,
#    world_size=world_size)
# For TcpStore, same way as on Linux.

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)

def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # setup mp_model and devices for this process
    dev0 = (rank * 2) % world_size
    dev1 = (rank * 2 + 1) % world_size
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)
    # run_demo(demo_model_parallel, world_size)

Output:


The following have been reloaded with a version change:
  1) GCCcore/10.3.0 => GCCcore/10.2.0
  2) binutils/2.36.1-GCCcore-10.3.0 => binutils/2.35-GCCcore-10.2.0
  3) zlib/1.2.11-GCCcore-10.3.0 => zlib/1.2.11-GCCcore-10.2.0

/home/users/aimhigh/.local/lib/python3.9/site-packages/torch/distributed/launch.py:163: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
  logger.warn(
The module torch.distributed.launch is deprecated and going to be removed in future.Migrate to torch.distributed.run
INFO:torch.distributed.run:Using nproc_per_node=auto, seting to 8 since the instance has 28 gpu
WARNING:torch.distributed.run:--use_env is deprecated and will be removed in future releases.
 Please read local_rank from `os.environ('LOCAL_RANK')` instead.
INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
  entrypoint       : main.py
  min_nodes        : 1
  max_nodes        : 1
  nproc_per_node   : 8
  run_id           : none
  rdzv_backend     : static
  rdzv_endpoint    : 127.0.0.1:29500
  rdzv_configs     : {'rank': 0, 'timeout': 900}
  max_restarts     : 3
  monitor_interval : 5
  log_dir          : None
  metrics_cfg      : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log directory set to: /tmp/torchelastic_tp__4tqc/none_we2sza_6
INFO:torch.distributed.elastic.agent.server.api:[default] starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous'ing worker group
/home/users/aimhigh/.local/lib/python3.9/site-packages/torch/distributed/elastic/utils/store.py:52: FutureWarning: This is an experimental API and will be changed in future.
  warnings.warn(
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous complete for workers. Result:
  restart_count=0
  master_addr=127.0.0.1
  master_port=29500
  group_rank=0
  group_world_size=1
  local_ranks=[0, 1, 2, 3, 4, 5, 6, 7]
  role_ranks=[0, 1, 2, 3, 4, 5, 6, 7]
  global_ranks=[0, 1, 2, 3, 4, 5, 6, 7]
  role_world_sizes=[8, 8, 8, 8, 8, 8, 8, 8]
  global_world_sizes=[8, 8, 8, 8, 8, 8, 8, 8]

INFO:torch.distributed.elastic.agent.server.api:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:Setting worker0 reply file to: /tmp/torchelastic_tp__4tqc/none_we2sza_6/attempt_0/0/error.json
INFO:torch.distributed.elastic.multiprocessing:Setting worker1 reply file to: /tmp/torchelastic_tp__4tqc/none_we2sza_6/attempt_0/1/error.json
INFO:torch.distributed.elastic.multiprocessing:Setting worker2 reply file to: /tmp/torchelastic_tp__4tqc/none_we2sza_6/attempt_0/2/error.json
INFO:torch.distributed.elastic.multiprocessing:Setting worker3 reply file to: /tmp/torchelastic_tp__4tqc/none_we2sza_6/attempt_0/3/error.json
INFO:torch.distributed.elastic.multiprocessing:Setting worker4 reply file to: /tmp/torchelastic_tp__4tqc/none_we2sza_6/attempt_0/4/error.json
INFO:torch.distributed.elastic.multiprocessing:Setting worker5 reply file to: /tmp/torchelastic_tp__4tqc/none_we2sza_6/attempt_0/5/error.json
INFO:torch.distributed.elastic.multiprocessing:Setting worker6 reply file to: /tmp/torchelastic_tp__4tqc/none_we2sza_6/attempt_0/6/error.json
INFO:torch.distributed.elastic.multiprocessing:Setting worker7 reply file to: /tmp/torchelastic_tp__4tqc/none_we2sza_6/attempt_0/7/error.json

When I check the GPUs and CPUs there is almost no activity at all, but the job keeps running, and there is no output beyond what I pasted above (nothing new in the Slurm log).

Hey @StevanCakic, for the above script, have you tried setting CUDA_VISIBLE_DEVICES for each spawned process before any torch operation, so that each process only sees one GPU? You can also try torch.cuda.set_device(), but I would recommend CUDA_VISIBLE_DEVICES, since that way you know for sure that each process is exclusively using the expected device.
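
Roughly something like this in the worker of the first script (an untested sketch, assuming one GPU per rank):

def _test_nccl_worker(rank, num_gpu, dist_url):
    import os
    # Must be set before any CUDA work in this process, so it only ever sees GPU `rank`
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)

    import torch.distributed as dist
    dist.init_process_group(backend="nccl", init_method=dist_url, rank=rank, world_size=num_gpu)
    # Only one device is visible, so the barrier cannot guess the wrong GPU
    dist.barrier()
    print("Worker after barrier")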


I will check it, thank you

Finally, I solved it: specifying device_ids in barrier() to force the use of a particular device, as the warning suggests.
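
For anyone hitting the same warning, the change in my worker looked roughly like this (a sketch, assuming one GPU per rank):

def _test_nccl_worker(rank, num_gpu, dist_url):
    import torch
    import torch.distributed as dist

    torch.cuda.set_device(rank)  # bind this process to its own GPU
    dist.init_process_group(backend="nccl", init_method=dist_url, rank=rank, world_size=num_gpu)
    # device_ids tells NCCL which GPU this rank uses for the barrier,
    # instead of the "best-guess GPU" from the warning
    dist.barrier(device_ids=[rank])
    print("Worker after barrier")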