Questions about Model Parallelism and DDP with NCCL backend

Hi, I have a huge model and a large image dataset to run, so I’m trying to use model parallelism and DDP at the same time, just like Part 3 of this tutorial.

However, when running the tutorial to try DDP with the NCCL backend, I hit the same problem described in this post:

NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device 1000
NCCL WARN Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 1000

The discussion also arrived at a solution:

net.to(f'cuda:{args.local_rank}')

But I just don’t know where to put this line when using the spawn function from the tutorial. Can anyone provide some sample code?


Another question is how to use the cleanup function (shown below) correctly:

def cleanup():
    dist.destroy_process_group()

Should we call cleanup after each epoch, after each batch, or just once at the very end of the program?


I’m using a single machine with 4 A100 GPUs (plus a small display GPU, as shown in the nvidia-smi output below):

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Graphics Device     On   | 00000000:01:00.0 Off |                    0 |
| N/A   40C    P0    65W / 275W |      0MiB / 81252MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  Graphics Device     On   | 00000000:47:00.0 Off |                    0 |
| N/A   39C    P0    66W / 275W |      0MiB / 81252MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  Graphics Device     On   | 00000000:81:00.0 Off |                    0 |
| N/A   39C    P0    66W / 275W |      0MiB / 81252MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  DGX Display         On   | 00000000:C1:00.0  On |                  N/A |
| 33%   48C    P8    N/A /  50W |    641MiB /  3911MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  Graphics Device     On   | 00000000:C2:00.0 Off |                    0 |
| N/A   39C    P0    62W / 275W |      0MiB / 81252MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

Python version: 3.9.12
PyTorch version: 1.10.2
torchvision version: 0.11.3


I’m a beginner with multi-GPU training and DDP, so any suggestions or advice would be very helpful.
Much appreciated.

Could you post a minimal code snippet that reproduces the error? The tutorial works fine for me and doesn’t reuse the same device for different ranks.

Hi @ptrblck ,

Thanks for replying. Yes, here’s my minimal code:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ['NCCL_P2P_DISABLE'] = '0'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)

def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # setup mp_model and devices for this process
    dev0 = (rank * 2) % world_size
    dev1 = (rank * 2 + 1) % world_size
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
    optimizer.zero_grad()

    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))


    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = 4
    run_demo(demo_model_parallel, world_size)

Basically, I’ve just changed the backend from gloo to nccl and set world_size=4.
Thanks!

Thanks for the code!
Since you are using NCCL, different ranks cannot share the same GPU, so I think you would need to change the module placement logic a bit.

E.g., for 8 GPUs this works:

    dev0 = (rank * 2) % (world_size*2)
    dev1 = (rank * 2 + 1) % (world_size*2)
    mp_model = ToyMpModel(dev0, dev1)

...
if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"

    world_size = n_gpus//2
    run_demo(demo_model_parallel, world_size)

This places the 4 model replicas onto all 8 GPUs (each replica is split across 2 GPUs).
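
For the 4-GPU machine above, the same idea gives world_size = 2, with each rank owning a disjoint pair of devices. A minimal sketch (assuming only the four A100s are visible to PyTorch, e.g. via CUDA_VISIBLE_DEVICES, so they map to cuda:0 through cuda:3):

def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # rank 0 -> (cuda:0, cuda:1), rank 1 -> (cuda:2, cuda:3): no device is shared
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)
    ...

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus // 2  # one process per pair of GPUs
    run_demo(demo_model_parallel, world_size)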

Thank you for helping! It works!

I’m also wondering how and where I should call the cleanup function when using DDP. Should it be at the very end of the whole process?

Yes, you can call cleanup once your workload has finished.
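
That is what demo_model_parallel in the snippet above already does: each spawned process calls cleanup() once, after its own training work is done. If you want to be more defensive, one option (just a sketch, not something the tutorial requires; train_loop is a hypothetical placeholder for the full epoch/batch loop) is to wrap the work in try/finally so the process group is destroyed even if training raises an error:

def demo_model_parallel(rank, world_size):
    setup(rank, world_size)
    try:
        # hypothetical helper: builds the model and runs all epochs/batches
        train_loop(rank, world_size)
    finally:
        # destroy the process group exactly once, when this process is done
        cleanup()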

Thanks for answering!