`torch.distributed.barrier` used in multi-node distributed data-parallel training

Thank you very much @iffiX. I will try gloo tomorrow.
Best,
Lei

@iffiX @mrshenli I just got time to test the gloo backend. The training seems to run without significant problems. However, I do have a concern: I found 7 processes on each node even though I requested only 4 GPUs per node.

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     53447      C   /opt/conda/bin/python                       1511MiB |
|    0     53448      C   /opt/conda/bin/python                        803MiB |
|    0     53449      C   /opt/conda/bin/python                        803MiB |
|    0     53450      C   /opt/conda/bin/python                        803MiB |
|    1     53448      C   /opt/conda/bin/python                       1511MiB |
|    2     53449      C   /opt/conda/bin/python                       1511MiB |
|    3     53450      C   /opt/conda/bin/python                       1511MiB |
+-----------------------------------------------------------------------------+

The GPU memory usage is also uneven:

$ nvidia-smi
Tue Jul 21 19:49:09 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.126.02   Driver Version: 418.126.02   CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   38C    P0    49W / 163W |   3933MiB / 32480MiB |     11%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   38C    P0    46W / 163W |   1522MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:0A:00.0 Off |                    0 |
| N/A   38C    P0    46W / 163W |   1522MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:0B:00.0 Off |                    0 |
| N/A   37C    P0    48W / 163W |   1522MiB / 32480MiB |      9%      Default |
+-------------------------------+----------------------+----------------------+
|   4  Tesla V100-SXM2...  On   | 00000000:85:00.0 Off |                    0 |
| N/A   36C    P0    42W / 163W |     11MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   5  Tesla V100-SXM2...  On   | 00000000:86:00.0 Off |                    0 |
| N/A   38C    P0    43W / 163W |     11MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   6  Tesla V100-SXM2...  On   | 00000000:89:00.0 Off |                    0 |
| N/A   38C    P0    43W / 163W |     11MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   7  Tesla V100-SXM2...  On   | 00000000:8A:00.0 Off |                    0 |
| N/A   37C    P0    41W / 163W |     11MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

Can you guys explain what’s happening here?

Regarding the nccl backend problem, I currently don’t have time to troubleshoot at a lower level, but I believe it is a bug, either in the nccl library or in the PyTorch implementation.

Thank you.

Best,

Lei

This is not an error, as you can see:

 0     53448      C   /opt/conda/bin/python                        803MiB
 1     53448      C   /opt/conda/bin/python                       1511MiB

Their PIDs are the same, so this is one process using two GPUs: it seems that DDP makes every “secondary” process (every rank except the “primary” one) also use GPU 0, probably for receiving tensors etc. The 803MiB is most likely the baseline memory of the CUDA context (largely the CUDA kernels) that any PyTorch process allocates on a device once it uses CUDA there. Actions such as moving a model to the GPU or creating a tensor on the GPU initialize CUDA. See this issue for a detailed explanation: issue
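To check the PID observation mechanically, here is a small stdlib-only sketch that parses the Processes table from nvidia-smi, counts distinct PIDs, lists PIDs that appear on more than one GPU, and sums memory per GPU. The regular expression assumes the row layout shown above and may need adjusting for other driver versions; `parse_processes` and `summarize` are illustrative names, not part of any tool.

```python
import re
from collections import defaultdict

# Matches one process row of the nvidia-smi "Processes" table, e.g.
# "|    0     53448      C   /opt/conda/bin/python                        803MiB |"
# (assumption: layout as printed by driver 418.x above)
ROW = re.compile(r"^\|\s*(\d+)\s+(\d+)\s+\S+\s+(.+?)\s+(\d+)MiB\s*\|$")


def parse_processes(table):
    """Return a list of (gpu, pid, memory_mib) tuples from the Processes table."""
    rows = []
    for line in table.splitlines():
        m = ROW.match(line.strip())
        if m:
            gpu, pid, _name, mem = m.groups()
            rows.append((int(gpu), int(pid), int(mem)))
    return rows


def summarize(rows):
    """Return (number of distinct PIDs, {pid: gpus} for multi-GPU PIDs, {gpu: total MiB})."""
    gpus_per_pid = defaultdict(set)
    mem_per_gpu = defaultdict(int)
    for gpu, pid, mem in rows:
        gpus_per_pid[pid].add(gpu)
        mem_per_gpu[gpu] += mem
    multi_gpu = {pid: sorted(g) for pid, g in gpus_per_pid.items() if len(g) > 1}
    return len(gpus_per_pid), multi_gpu, dict(mem_per_gpu)
```

Fed the 7-row table from this thread, it reports 4 distinct PIDs, with 53448, 53449 and 53450 each present on GPU 0 plus their own GPU. The per-GPU sums also line up with nvidia-smi: GPU 0 holds 1511 + 3 × 803 = 3920 MiB and GPUs 1 to 3 hold 1511 MiB each, versus 3933 MiB and 1522 MiB reported; the remaining 11 to 13 MiB roughly matches the 11MiB shown on the idle GPUs 4 to 7.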

I can also replicate this behavior on my machine, so don’t worry about it.

The replication script is a slightly modified version of the DDP tutorial:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-node rendezvous settings for the default process group
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "9003"


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model on GPU `rank`
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # keep training indefinitely so GPU usage can be inspected with nvidia-smi
    while True:
        # clear gradients from the previous iteration
        optimizer.zero_grad()
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    main()
```

Thank you @iffiX. I have always used nccl before, and as far as I remember, the GPU memory usage was always the same on every GPU of every node.

Then it is probably specific to the backend implementation.

Sorry for being late to the discussion.

I saw you were using ddp_model.load_state_dict to load model parameters. Is this method untested or discouraged?

Right, we don’t have tests for saving/loading DDP models yet, IIUC. Let me create an issue to track this.
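One detail that often bites when moving checkpoints between a DDP-wrapped model and a plain one: DistributedDataParallel keeps the wrapped model as `ddp_model.module`, so keys in `ddp_model.state_dict()` carry a `module.` prefix. A minimal stdlib sketch of a helper (the name `strip_ddp_prefix` is hypothetical) that makes such a checkpoint loadable into the unwrapped model:

```python
from collections import OrderedDict


def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with the DDP wrapper prefix removed from keys.

    Keys without the prefix are kept unchanged, so the helper is safe to call on
    checkpoints saved from either a DDP model or a plain model.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
```

A common alternative is to save `ddp_model.module.state_dict()` from rank 0 only and call `load_state_dict` on the plain model before wrapping it in DDP. When every rank calls `torch.load`, passing `map_location` (for example `{"cuda:0": f"cuda:{local_rank}"}`) keeps tensors that were saved from cuda:0 from being restored onto GPU 0 in every process, which could otherwise add to GPU 0's memory usage.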

So the second node got halted in

The DDP constructor does have a broadcast op; I believe that is where it halts:

Looking at the log, some ranks proceed beyond 2.1 while others are waiting at 2.1, which suggests the processes are out of sync. Out of curiosity, why is there no output for Location 0 at rank 0? Is it just because the print for Location 0 is inside the if clause?
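With many ranks, spotting the desync by eye gets tedious. A small sketch that finds the last location marker each rank printed and lists ranks that are behind the furthest one, assuming a hypothetical log format of `[rank N] Location X.Y` per line (adjust the regular expression to whatever the actual prints look like):

```python
import re

# Hypothetical log format: "[rank 3] Location 2.1"
LINE = re.compile(r"\[rank (\d+)\] Location (\d+(?:\.\d+)*)")


def last_location_per_rank(log_text):
    """Map each rank to the last location marker it printed."""
    last = {}
    for line in log_text.splitlines():
        m = LINE.search(line)
        if m:
            last[int(m.group(1))] = m.group(2)
    return last


def location_key(loc):
    """Compare markers numerically per component, so "2" < "2.1" < "10"."""
    return tuple(int(part) for part in loc.split("."))


def lagging_ranks(last):
    """Ranks whose last marker is behind the furthest rank's last marker."""
    if not last:
        return []
    furthest = max(location_key(loc) for loc in last.values())
    return sorted(r for r, loc in last.items() if location_key(loc) < furthest)
```

If every rank executes the same sequence of collectives, all ranks should end on the same marker when the job hangs; a non-empty result points at where the ranks diverged.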

For the log, can you also print dist.get_world_size(), and use dist.get_rank() instead of the local rank? That will let us verify whether the launch script did anything wrong.
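If each process prints something like `print(f"rank={dist.get_rank()} world_size={dist.get_world_size()}", flush=True)` right after `init_process_group`, the combined output from both nodes can be checked with a short sketch like the one below. It assumes that exact print format and, for this thread's setup, two nodes with 4 processes each (world size 8); `check_launch` is a hypothetical helper.

```python
import re

# Matches the assumed report line "rank=<global rank> world_size=<size>"
REPORT = re.compile(r"rank=(\d+) world_size=(\d+)")


def check_launch(log_text, expected_world_size):
    """Return a list of problems found in the rank/world-size reports (empty if OK)."""
    ranks, sizes = [], set()
    for line in log_text.splitlines():
        m = REPORT.search(line)
        if m:
            ranks.append(int(m.group(1)))
            sizes.add(int(m.group(2)))
    problems = []
    if sizes != {expected_world_size}:
        problems.append(f"world sizes reported: {sorted(sizes)}")
    dupes = sorted({r for r in ranks if ranks.count(r) > 1})
    if dupes:
        problems.append(f"duplicate ranks: {dupes}")
    missing = sorted(set(range(expected_world_size)) - set(ranks))
    if missing:
        problems.append(f"missing ranks: {missing}")
    return problems
```

A typical launcher mistake, such as the second node reusing global ranks 0 to 3 instead of 4 to 7, shows up as both duplicate and missing ranks.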

I found the number of processes is 7 on each node despite the fact that I requested using 4 GPU on each node.

It looks like the other processes (local_rank != 0) also created a CUDA context and allocated some tensors on cuda:0. You can avoid this by setting the CUDA_VISIBLE_DEVICES environment variable for each subprocess, either directly on the command line or in the program before any CUDA logic runs. See Running on specific GPU device.
Note that after this change, you will also need to change every f'cuda:{local_rank}' to 'cuda:0', as each process now sees only one device.
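A minimal sketch of the command-line variant: a hypothetical per-node launcher that gives each worker an environment exposing exactly one GPU. It assumes the training script (here a placeholder `script` path) reads its local rank from a `LOCAL_RANK` environment variable and uses `cuda:0` as its device; the rank, world-size and rendezvous setup a real multi-node launcher also needs is omitted.

```python
import os
import subprocess
import sys


def child_env(local_rank, base_env=None):
    """Environment for one worker that can see exactly one GPU.

    Inside the child, physical GPU `local_rank` is renumbered as cuda:0.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = str(local_rank)
    env["LOCAL_RANK"] = str(local_rank)  # assumption: the training script reads this
    return env


def launch(script, gpus_per_node):
    """Start one worker per GPU on this node and wait for all of them (hypothetical)."""
    procs = [
        subprocess.Popen([sys.executable, script], env=child_env(i))
        for i in range(gpus_per_node)
    ]
    return [p.wait() for p in procs]
```

For the in-program variant, the same assignment to `os.environ["CUDA_VISIBLE_DEVICES"]` has to happen at the very top of the training script, before anything initializes CUDA; setting it afterwards has no effect on the already-initialized process.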

Hmm, this is weird. The fact that the Gloo backend works means that all ranks and world sizes are configured properly. Still, let’s double-check using dist.get_world_size() and dist.get_rank().

If so, then the broadcast in DDP might not be what caused the hang. Do you have access to the PyTorch Python files in your local environment? Can you try wrapping this with some print statements?
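A small helper for the "wrap it with prints" step, sketched as a context manager (the name `trace` is hypothetical). It flushes each message immediately so the output reaches the log even if the process then hangs; a rank that prints "entering" but never "leaving" is stuck inside the wrapped call.

```python
import contextlib
import time


def _print_flush(msg):
    # Flush right away; buffered output can be lost when a job hangs and is killed
    print(msg, flush=True)


@contextlib.contextmanager
def trace(label, rank, log=_print_flush):
    """Log entry into and exit from the wrapped block, tagged with the rank."""
    log(f"[rank {rank}] entering {label}")
    start = time.perf_counter()
    try:
        yield
    finally:
        log(f"[rank {rank}] leaving {label} after {time.perf_counter() - start:.3f}s")
```

For example, `with trace("DDP constructor", dist.get_rank()): ddp_model = DDP(model, device_ids=[local_rank])`, or the same `with` block placed around the suspected call inside the PyTorch source.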

@leimao, one more question about your test environment. Am I correct to assume that you have two 8-GPU machines, that you are using the first 4 GPUs (cuda:0-3) on each of them, and that you have exclusive access to those GPUs?

Yes. I have 8 GPUs on each node, but I only used 4 of them. I could have used all of them.