`torch.distributed.barrier` used in multi-node distributed data-parallel training

For now, resume is always False during my tests, i.e., the model is always trained from scratch, so we can safely ignore that code for now.


To put it simply: if you just want one process (e.g., rank 0) to execute mkdir, download, etc., then you should do:

import torch
import argparse


def main():
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    local_rank = args.local_rank
    
    torch.distributed.barrier()

    if local_rank == 0:
        print(local_rank)
    
    torch.distributed.barrier()

    print("{} exit".format(local_rank))


if __name__ == "__main__":
    main()

this will print:

0
0 exit
2 exit
1 exit3 exit

And you should not do this:

import torch
import argparse


def main():
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    local_rank = args.local_rank
    
    if local_rank != 0:
        torch.distributed.barrier()

    print(local_rank)
    
    if local_rank == 0:
        torch.distributed.barrier()

    print("{} exit".format(local_rank))


if __name__ == "__main__":
    main()

which will print

0
0 exit
2
2 exit
13
3 exit

1 exit

A barrier is just a barrier: it requires all processes in the group to reach a barrier call, no matter where that call is placed. So the second snippet basically just delays every process except rank 0. Unless the code between the two barriers becomes a no-op (equivalent to return / pass) once any one process (e.g., process 0) has executed it, you are not going to get the result you expect.
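For reference, a minimal sketch of the pattern that gives “only rank 0 does the one-time work, everyone else waits for it” (the process group is assumed to be initialized already, and download_dataset is a hypothetical placeholder):

import torch

# Assumes torch.distributed.init_process_group() has already been called.
if torch.distributed.get_rank() == 0:
    download_dataset()  # hypothetical placeholder for the one-time work (mkdir, download, ...)
torch.distributed.barrier()  # every other rank waits here until rank 0 arrives
# From here on, all ranks proceed together and can safely use the downloaded files.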

Also, please make sure that your CUDA runtime has the same major and minor version as the CUDA version your torch build was compiled with; CUDA 9 is not compatible with CUDA 10, so otherwise you are likely to run into issues when using “nccl” or doing CUDA tensor computations.
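If it helps, here is a quick sketch for printing the versions a given PyTorch build reports, to be run in the same environment you launch training from:

import torch

print(torch.__version__)               # PyTorch version
print(torch.version.cuda)              # CUDA version this PyTorch build was compiled with
print(torch.backends.cudnn.version())  # cuDNN version
print(torch.cuda.is_available())       # whether the CUDA runtime is usable at all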


Thank you very much for repeating all the experiments @iffiX. I wanted to download the CIFAR-10 dataset using local rank 0 only, and once local rank 0 has downloaded the dataset, local ranks 1, 2, and 3 could proceed and use the downloaded cache for data preprocessing.

    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=transform) 
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=transform)

However, I don’t see how your solution,

import torch
import argparse


def main():
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    local_rank = args.local_rank
    
    torch.distributed.barrier()

    if local_rank == 0:
        print(local_rank)
    
    torch.distributed.barrier()

    print("{} exit".format(local_rank))


if __name__ == "__main__":
    main()

is able to do this.
The printout of your second code snippet, in particular,

def main():
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    local_rank = args.local_rank
    
    if local_rank != 0:
        torch.distributed.barrier()

    print(local_rank)
    
    if local_rank == 0:
        torch.distributed.barrier()

    print("{} exit".format(local_rank))


if __name__ == "__main__":
    main()

is expected, and it is also what I was trying to implement. I want local rank 0 to do all that stuff once, and then have local ranks 1, 2, and 3 start doing the same stuff in their own processes.

I think my CUDA version is compatible with PyTorch. I am using CUDA 10.2 + PyTorch 1.5.1.

The “asynchronous barrier” was also used in the HuggingFace example that I mentioned above. Since many people are using HuggingFace, I think their code at least runs fine on a single node.

I thought of an inelegant way to work around it:

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import argparse
import os
import random
import numpy as np

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def evaluate(model, device, test_loader):

    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total

    return accuracy

def main():

    num_epochs_default = 100
    batch_size_default = 256 # 1024
    learning_rate_default = 0.1
    random_seed_default = 0
    model_dir_default = "saved_models"
    model_filename_default = "resnet_distributed.pth"

    # Each process runs on 1 GPU device specified by the local_rank argument.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs.", default=num_epochs_default)
    parser.add_argument("--batch_size", type=int, help="Training batch size for one process.", default=batch_size_default)
    parser.add_argument("--learning_rate", type=float, help="Learning rate.", default=learning_rate_default)
    parser.add_argument("--random_seed", type=int, help="Random seed.", default=random_seed_default)
    parser.add_argument("--model_dir", type=str, help="Directory for saving models.", default=model_dir_default)
    parser.add_argument("--model_filename", type=str, help="Model filename.", default=model_filename_default)
    parser.add_argument("--resume", action="store_true", help="Resume training from saved checkpoint.")
    argv = parser.parse_args()

    local_rank = argv.local_rank
    num_epochs = argv.num_epochs
    batch_size = argv.batch_size
    learning_rate = argv.learning_rate
    random_seed = argv.random_seed
    model_dir = argv.model_dir
    model_filename = argv.model_filename
    resume = argv.resume

    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
    # torch.distributed.init_process_group(backend="gloo")

    # torch.distributed.barrier()
    # Create directories outside the PyTorch program
    # Only create directory in one process because it is not multiprocess safe
    if local_rank == 0:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

    # Prepare dataset and dataloader
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])


    if local_rank == 0:
        train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=transform) 
        test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
        
    torch.distributed.barrier()

    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=transform) 
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=transform)

    model_filepath = os.path.join(model_dir, model_filename)


    # We need to use seeds to make sure that the models initialized in different processes are the same
    set_random_seeds(random_seed=random_seed)

    # Encapsulate the model on the GPU assigned to the current process
    model = torchvision.models.resnet18(pretrained=False)

    device = torch.device("cuda:{}".format(local_rank))
    model = model.to(device)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    # We only save the model who uses device "cuda:0"
    # To resume, the device for the saved model would also be "cuda:0"
    if resume:
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        ddp_model.load_state_dict(torch.load(model_filepath, map_location=map_location))

    # Restricts data loading to a subset of the dataset exclusive to the current process
    train_sampler = DistributedSampler(dataset=train_set)

    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler, num_workers=8)
    # Test loader does not have to follow distributed sampling strategy
    test_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=8)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # Loop over the dataset multiple times
    for epoch in range(num_epochs):

        print("Local Rank: {}, Epoch: {}, Training ...".format(local_rank, epoch))
        
        # Save and evaluate model routinely
        if epoch % 10 == 0:
            if local_rank == 0:
                accuracy = evaluate(model=ddp_model, device=device, test_loader=test_loader)
                torch.save(ddp_model.state_dict(), model_filepath)
                print("-" * 75)
                print("Epoch: {}, Accuracy: {}".format(epoch, accuracy))
                print("-" * 75)

        ddp_model.train()

        for data in train_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    
    main()

But it still got stuck.
On node 0:

100.0%Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Local Rank: 3, Epoch: 0, Training ...
Local Rank: 2, Epoch: 0, Training ...
Local Rank: 1, Epoch: 0, Training ...
Local Rank: 0, Epoch: 0, Training ...

On node 1:

100.0%Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified

@mrshenli I commented out the model saving code but it still halted.

After reading your code a little bit more carefully, I agree that you may use the second solution, since all processes need to create the data loader, so the problem is not there.
Could you please try adding some print statements such as:

print("line230")
...
print("line232")

to show exactly where your code has halted? The current log is too limited to determine the exact statement that caused your code to halt.
And don’t forget to take care of ddp_model.load_state_dict(torch.load(model_filepath, map_location=map_location)) after solving the halting issue, as @mrshenli said.

@mrshenli In your tutorial (https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints), I saw you were using ddp_model.load_state_dict to load model parameters. Is this method untested and unfavored?
I remember the example I documented in my blog post works perfectly. I tested model resuming a while ago and it worked fine. It only started having problems when I tried to add some barrier functions a few days ago.
Thank you.

@iffiX @mrshenli It seems that I have located where the halting is happening. Running the following code:

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import argparse
import os
import random
import numpy as np

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def evaluate(model, device, test_loader):

    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total

    return accuracy

def main():

    num_epochs_default = 100
    batch_size_default = 256 # 1024
    learning_rate_default = 0.1
    random_seed_default = 0
    model_dir_default = "saved_models"
    model_filename_default = "resnet_distributed.pth"

    # Each process runs on 1 GPU device specified by the local_rank argument.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs.", default=num_epochs_default)
    parser.add_argument("--batch_size", type=int, help="Training batch size for one process.", default=batch_size_default)
    parser.add_argument("--learning_rate", type=float, help="Learning rate.", default=learning_rate_default)
    parser.add_argument("--random_seed", type=int, help="Random seed.", default=random_seed_default)
    parser.add_argument("--model_dir", type=str, help="Directory for saving models.", default=model_dir_default)
    parser.add_argument("--model_filename", type=str, help="Model filename.", default=model_filename_default)
    parser.add_argument("--resume", action="store_true", help="Resume training from saved checkpoint.")
    argv = parser.parse_args()

    local_rank = argv.local_rank
    num_epochs = argv.num_epochs
    batch_size = argv.batch_size
    learning_rate = argv.learning_rate
    random_seed = argv.random_seed
    model_dir = argv.model_dir
    model_filename = argv.model_filename
    resume = argv.resume

    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
    # torch.distributed.init_process_group(backend="gloo")

    if local_rank != 0:
        torch.distributed.barrier()
    
    print("Local Rank: {} | Location: {}".format(local_rank, 0))

    # Create directories outside the PyTorch program
    # Only create directory in one process because it is not multiprocess safe
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Prepare dataset and dataloader
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=transform) 
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=transform)

    model_filepath = os.path.join(model_dir, model_filename)

    # We need to use seeds to make sure that the models initialized in different processes are the same
    set_random_seeds(random_seed=random_seed)

    # Encapsulate the model on the GPU assigned to the current process
    model = torchvision.models.resnet18(pretrained=False)

    print("Local Rank: {} | Location: {}".format(local_rank, 1))

    if local_rank == 0:
        torch.distributed.barrier()

    print("Local Rank: {} | Location: {}".format(local_rank, 2))

    device = torch.device("cuda:{}".format(local_rank))
    model = model.to(device)
    print("Local Rank: {} | Location: {}".format(local_rank, 2.1))
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    print("Local Rank: {} | Location: {}".format(local_rank, 2.2))

    # We only save the model who uses device "cuda:0"
    # To resume, the device for the saved model would also be "cuda:0"
    if resume:
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        ddp_model.load_state_dict(torch.load(model_filepath, map_location=map_location))

    
    # Restricts data loading to a subset of the dataset exclusive to the current process
    train_sampler = DistributedSampler(dataset=train_set)
    

    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler, num_workers=8)
    # Test loader does not have to follow distributed sampling strategy
    test_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=8)
    print("Local Rank: {} | Location: {}".format(local_rank, 2.3))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # Loop over the dataset multiple times
    for epoch in range(num_epochs):

        print("Local Rank: {}, Epoch: {}, Training ...".format(local_rank, epoch))

        print("Local Rank: {} | Location: {}".format(local_rank, 3))
        
        # Save and evaluate model routinely
        if epoch % 10 == 0:
            if local_rank == 0:
                accuracy = evaluate(model=ddp_model, device=device, test_loader=test_loader)
                torch.save(ddp_model.state_dict(), model_filepath)
                print("-" * 75)
                print("Epoch: {}, Accuracy: {}".format(epoch, accuracy))
                print("-" * 75)

        print("Local Rank: {} | Location: {}".format(local_rank, 4))

        ddp_model.train()

        for data in train_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    
    main()

For node 0:

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Local Rank: 0 | Location: 1
Local Rank: 0 | Location: 2
Local Rank: 2 | Location: 0
Local Rank: 3 | Location: 0
Local Rank: 1 | Location: 0
Local Rank: 0 | Location: 2.1
Local Rank: 0 | Location: 2.2
Local Rank: 0 | Location: 2.3
Local Rank: 0, Epoch: 0, Training ...
Local Rank: 0 | Location: 3
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Local Rank: 2 | Location: 1
Local Rank: 2 | Location: 2
Local Rank: 1 | Location: 1
Local Rank: 1 | Location: 2
Local Rank: 3 | Location: 1
Local Rank: 3 | Location: 2
Local Rank: 2 | Location: 2.1
Local Rank: 1 | Location: 2.1
Local Rank: 3 | Location: 2.1
Local Rank: 2 | Location: 2.2
Local Rank: 2 | Location: 2.3
Local Rank: 1 | Location: 2.2
Local Rank: 1 | Location: 2.3
Local Rank: 2, Epoch: 0, Training ...
Local Rank: 2 | Location: 3
Local Rank: 2 | Location: 4
Local Rank: 1, Epoch: 0, Training ...
Local Rank: 1 | Location: 3
Local Rank: 1 | Location: 4
Local Rank: 3 | Location: 2.2
Local Rank: 3 | Location: 2.3
Local Rank: 3, Epoch: 0, Training ...
Local Rank: 3 | Location: 3
Local Rank: 3 | Location: 4

For node 1:

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Local Rank: 0 | Location: 1
Local Rank: 0 | Location: 2
Local Rank: 2 | Location: 0
Local Rank: 3 | Location: 0
Local Rank: 1 | Location: 0
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Local Rank: 0 | Location: 2.1
Local Rank: 2 | Location: 1
Local Rank: 2 | Location: 2
Local Rank: 1 | Location: 1
Local Rank: 1 | Location: 2
Local Rank: 3 | Location: 1
Local Rank: 3 | Location: 2
Local Rank: 2 | Location: 2.1
Local Rank: 1 | Location: 2.1
Local Rank: 3 | Location: 2.1

So the second node got halted in

ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

Since you are running 1.5.1, I just dove into the 1.5.1 code and can verify that the current DistributedDataParallel does have a _sync_params method, which broadcasts all parameters and buffers and then sets the local params with the in-place operation set_:

def _sync_params(self):
    with torch.no_grad():
        # only do intra-node parameters sync for replicated single-device
        # CUDA modules
        if self.device_ids and len(self.device_ids) > 1:
            # intra-node parameter sync
            result = torch.cuda.comm.broadcast_coalesced(
                self.modules_params[0],
                self.device_ids,
                self.broadcast_bucket_size)
            for tensors, module_params in zip(result[1:],
                                              self.modules_params[1:]):
                for tensor, param in zip(tensors, module_params):
                    param.set_(tensor)
                    # Assume we have just run the optimizer and zeroed the
                    # grads of the parameters on the root model. We need
                    # to zero the grads on all model replicas as well.
                    # This snippet is copied from torch.optim.Optimizer.
                    if param.grad is not None:
                        param.grad.detach_()
                        param.grad.zero_()

And _sync_params will be invoked when you perform a forward operation, if syncing is enabled:

def forward(self, *inputs, **kwargs):
    if self.require_forward_param_sync:
        self._sync_params()

So load_state_dict() should work, theoretically, because the newly loaded params will be broadcast to the other processes.
Sorry about my outdated knowledge above.
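For completeness, a minimal sketch of the save/load pattern this implies, reusing the ddp_model, model_filepath, and local_rank names from your script (assumes the checkpoint path is visible to all ranks and the process group is initialized):

import torch

# Rank 0 writes the checkpoint; everyone waits until the file exists; then all ranks load it.
if torch.distributed.get_rank() == 0:
    torch.save(ddp_model.state_dict(), model_filepath)
torch.distributed.barrier()
map_location = {"cuda:0": "cuda:{}".format(local_rank)}
ddp_model.load_state_dict(torch.load(model_filepath, map_location=map_location))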

I think your code is correct; there really isn’t any visible issue with:

    model = model.to(device)
    print("Local Rank: {} | Location: {}".format(local_rank, 2.1))
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

My knowledge is not enough to explain this behavior. Some possible debugging steps:

  1. Does the “gloo” backend also halt? (See the sketch after this list for the one-line switch.)
  2. Insert some more print tracers into the PyTorch source code.
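For item 1, the backend switch is a one-line change, and turning on NCCL’s own logging may also show where the hang happens (a sketch, assuming the standard NCCL_DEBUG environment variable):

import os
import torch

# Ask NCCL to log what it is doing; must be set before init_process_group.
os.environ["NCCL_DEBUG"] = "INFO"

# Or simply switch the backend and check whether the hang persists.
torch.distributed.init_process_group(backend="gloo")  # instead of "nccl"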

It is most likely a problem with nccl, because DDP basically does these things during initialization:

  1. call dist._broadcast_coalesced to broadcast parameters to all groups

    dist._broadcast_coalesced is defined in torch/csrc/distributed/c10d/comm.cpp;
    however, since it is a private function, there is no indication of whether it is blocking, etc. I only know that it is invoked by all processes.

  2. call _ddp_init_helper, which basically only does some local operations, such as:

    Initialization helper function that does the following:
    
         (1) replicating the module from device[0] to the other devices
         (2) bucketing the parameters for reductions
         (3) resetting the bucketing states
         (4) registering the grad hooks
         (5) passing a handle of DDP to SyncBatchNorm Layer
    

You can also check your nccl installation, but this might not help you much if the “gloo” backend also halts.
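If it is useful, one way to print the NCCL and CUDA versions that your PyTorch build reports (my suggestion, not necessarily the exact check meant above):

import torch

print(torch.cuda.nccl.version())  # NCCL version bundled with this PyTorch build
print(torch.version.cuda)         # CUDA version PyTorch was compiled with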

Sorry that I cannot help you more with this problem.

Thank you very much @iffiX. I will try gloo tomorrow.
Best,
Lei

@iffiX @mrshenli I just got time to test the gloo backend. It seems that the training runs without significant problems. However, I do have concerns. I found that the number of processes is 7 on each node, despite the fact that I requested 4 GPUs on each node.

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     53447      C   /opt/conda/bin/python                       1511MiB |
|    0     53448      C   /opt/conda/bin/python                        803MiB |
|    0     53449      C   /opt/conda/bin/python                        803MiB |
|    0     53450      C   /opt/conda/bin/python                        803MiB |
|    1     53448      C   /opt/conda/bin/python                       1511MiB |
|    2     53449      C   /opt/conda/bin/python                       1511MiB |
|    3     53450      C   /opt/conda/bin/python                       1511MiB |
+-----------------------------------------------------------------------------+

The GPU memory usage is not even across the GPUs either.

$ nvidia-smi
Tue Jul 21 19:49:09 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.126.02   Driver Version: 418.126.02   CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   38C    P0    49W / 163W |   3933MiB / 32480MiB |     11%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   38C    P0    46W / 163W |   1522MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:0A:00.0 Off |                    0 |
| N/A   38C    P0    46W / 163W |   1522MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:0B:00.0 Off |                    0 |
| N/A   37C    P0    48W / 163W |   1522MiB / 32480MiB |      9%      Default |
+-------------------------------+----------------------+----------------------+
|   4  Tesla V100-SXM2...  On   | 00000000:85:00.0 Off |                    0 |
| N/A   36C    P0    42W / 163W |     11MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   5  Tesla V100-SXM2...  On   | 00000000:86:00.0 Off |                    0 |
| N/A   38C    P0    43W / 163W |     11MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   6  Tesla V100-SXM2...  On   | 00000000:89:00.0 Off |                    0 |
| N/A   38C    P0    43W / 163W |     11MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   7  Tesla V100-SXM2...  On   | 00000000:8A:00.0 Off |                    0 |
| N/A   37C    P0    41W / 163W |     11MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

Can you guys explain what’s happening here?

Regarding the nccl backend problem, I currently don’t have time to troubleshoot at a lower level. But I believe it is a bug, either in the nccl library or in the PyTorch implementation.

Thank you.

Best,

Lei

This is not an error, as you can see:

 0     53448      C   /opt/conda/bin/python                        803MiB
 1     53448      C   /opt/conda/bin/python                       1511MiB

Their PIDs are the same; it seems that each “secondary process” (every process except the “primary process”) also allocates something on GPU 0, probably for receiving tensors etc. The 803MiB should be the base CUDA context memory used whenever a process initializes CUDA in PyTorch; actions such as moving a model to a GPU or creating a tensor on a GPU will initialize CUDA. See this issue for a detailed explanation: issue
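As a side note, a common way to keep those stray allocations off GPU 0 (a general pattern, not something specific to this thread) is to pin each process to its own device before doing any CUDA work:

import torch

# Run this early in each process, before creating any CUDA tensors or models,
# so that default allocations and collectives use this process's own GPU.
torch.cuda.set_device(local_rank)  # local_rank as parsed from --local_rank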

I can also replicate this behavior on my machine, so don’t worry about it:

The replication script is a slightly modified version of the one from the DDP tutorial:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import os
os.environ["MASTER_ADDR"]="localhost"
os.environ["MASTER_PORT"]="9003"

def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    while True:
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

def main():
    world_size = 2
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

Thank you @iffiX. I used to always use nccl. In my experience, the GPU memory occupancy was always the same for each GPU on each node.

Then it is implementation related.

Sorry for being late to the discussion.

I saw you were using ddp_model.load_state_dict to load model parameters. Is this method untested and unfavored?

Right, we don’t have tests for saving/loading DDP models yet, IIUC. Let me create an issue to track it.

So the second node got halted in

The DDP constructor does have a broadcast op; I believe that’s where it halted:

Looking at the log, some ranks proceed beyond 2.1 while others are waiting at 2.1, which suggests there is a desync across the processes. Curious: why is there no output for Location 0 at rank 0? Is it just because the print for Location 0 is actually inside the if clause?

For the log, can you also try printing dist.get_world_size(), and use dist.get_rank() instead of the local rank? Let’s verify whether the launching script did anything wrong.
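Something like this near the top of main() would do it (a sketch; dist here is torch.distributed):

print("world_size={}, global_rank={}, local_rank={}".format(
    torch.distributed.get_world_size(),
    torch.distributed.get_rank(),
    local_rank))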

I found that the number of processes is 7 on each node, despite the fact that I requested 4 GPUs on each node.

Looks like the other processes (local_rank != 0) also created a CUDA context and allocated some tensors on cuda:0. You can avoid this by setting the CUDA_VISIBLE_DEVICES variable for each subprocess, either directly on the command line or in the program before loading any CUDA logic. See Running on specific GPU device.
Note that after this change, you will also need to change all f'cuda:{local_rank}' to cuda:0, as each process now only sees one device.
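A minimal sketch of that change (assuming the torch.distributed.launch style --local_rank argument; the variable must be set before any CUDA call happens in the process):

import os
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)  # this process now sees only one GPU
device = torch.device("cuda:0")                       # so the visible device is always index 0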

Hmm, this is weird. The gloo backend working means that all ranks and world sizes are configured properly. Let’s still double check using dist.get_world_size() and dist.get_rank().

If this is the case, then the broadcast in DDP might not be the place that caused the hang. Do you have access to the PyTorch Python files in your local env? Can you try adding some prints to wrap this?

@leimao one more question regarding your test env. Would I be correct if I assume you have two 8-GPU machines, and you are using the first 4 GPUs (cuda:0-3) on those two machines, and you have exclusive access to those GPUs?

Yes. I have 8 GPUs on each node, but I just used 4 of them. I could have been using all of them.