Problem with 3-node distributed training with torchrun

Hi,
I have a problem running distributed PyTorch training with torchrun. First of all, this is the script I am trying to run:

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import argparse
import os
import random
import numpy as np
import time
import importlib

# Each launcher exposes the same three facts (local rank, world size, global
# rank) under different environment variable names. The first launcher whose
# local-rank variable is present wins; torchrun / torch.distributed.launch is
# checked before mpirun.
_LAUNCHER_ENV_KEYS = (
    # torch.distributed.launch or torchrun
    ('LOCAL_RANK', 'WORLD_SIZE', 'RANK'),
    # mpirun (Open MPI)
    ('OMPI_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_SIZE', 'OMPI_COMM_WORLD_RANK'),
)

for _local_key, _size_key, _rank_key in _LAUNCHER_ENV_KEYS:
    if _local_key in os.environ:
        LOCAL_RANK = int(os.environ[_local_key])
        WORLD_SIZE = int(os.environ[_size_key])
        WORLD_RANK = int(os.environ[_rank_key])
        break
else:
    # No known launcher variables found: abort before any training code runs.
    import sys
    sys.exit("Can't find the evironment variables for local rank")

def set_random_seeds(random_seed=0):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        random_seed: Integer seed applied to ``random``, ``numpy.random``
            and ``torch``.
    """
    # Seed every random number generator the script relies on.
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Request deterministic cuDNN kernels and disable the auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def evaluate(model, device, test_loader):
    """Compute top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and is left in eval mode on return;
    the caller is responsible for calling ``train()`` again.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores
            of shape ``(batch, num_classes)``.
        device: Device the input and label tensors are moved to.
        test_loader: Iterable of batches where ``batch[0]`` holds the inputs
            and ``batch[1]`` the integer class labels.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        the loader yields no samples.
    """
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            # The predicted class is the index of the highest score per sample.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0

    return correct / total


def main():
    """Train a torchvision ResNet with DistributedDataParallel and time the epochs.

    Parses the command-line options, joins the process group using the
    module-level ``WORLD_RANK`` / ``WORLD_SIZE``, wraps the selected model in
    DDP on ``cuda:LOCAL_RANK`` and trains it with SGD on either CIFAR-10 or,
    with ``--use_syn``, a fixed random batch. On real data, the process with
    ``LOCAL_RANK == 0`` evaluates and checkpoints the model every 10 epochs.
    The time of every epoch except epoch 0 is recorded and the average is
    printed at the end.
    """

    num_epochs_default = 10000
    batch_size_default = 256 # 1024
    learning_rate_default = 0.1
    random_seed_default = 0
    model_dir_default = "saved_models"
    model_filename_default = "resnet_distributed.pth"
    # Width, height and channel count of one synthetic sample, and the number
    # of optimizer steps that make up one synthetic "epoch".
    w = 32
    h = 32
    c = 3
    num_steps_syn = 20

    # Each process runs on 1 GPU device specified by the local_rank argument.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs.", default=num_epochs_default)
    parser.add_argument("--batch_size", type=int, help="Training batch size for one process.", default=batch_size_default)
    parser.add_argument("--learning_rate", type=float, help="Learning rate.", default=learning_rate_default)
    parser.add_argument("--random_seed", type=int, help="Random seed.", default=random_seed_default)
    parser.add_argument("--model_dir", type=str, help="Directory for saving models.", default=model_dir_default)
    parser.add_argument("--model_filename", type=str, help="Model filename.", default=model_filename_default)
    parser.add_argument("--resume", action="store_true", help="Resume training from saved checkpoint.")
    parser.add_argument("--backend", type=str, help="Backend for distribted training.", default='nccl', choices=['nccl', 'gloo', 'mpi'])
    parser.add_argument("--arch", type=str, help="Model architecture.", default='resnet50', choices=['resnet50', 'resnet18', 'resnet101', 'resnet152'])
    parser.add_argument("--use_syn", action="store_true", help="Use synthetic data")
    argv = parser.parse_args()

    # --local_rank is only accepted so the launcher can pass it; the rank that
    # is actually used comes from the module-level LOCAL_RANK.
    num_epochs = argv.num_epochs
    batch_size = argv.batch_size
    learning_rate = argv.learning_rate
    random_seed = argv.random_seed
    model_dir = argv.model_dir
    model_filename = argv.model_filename
    resume = argv.resume
    backend = argv.backend
    use_syn = argv.use_syn

    # The model directory must be created outside this program: creating it
    # here is not multiprocess safe.
    model_filepath = os.path.join(model_dir, model_filename)

    # We need to use seeds to make sure that the models initialized in different processes are the same
    set_random_seeds(random_seed=random_seed)

    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    print(WORLD_SIZE)
    torch.distributed.init_process_group(backend=backend, rank=WORLD_RANK, world_size=WORLD_SIZE)

    # Encapsulate the model on the GPU assigned to the current process
    model = getattr(torchvision.models, argv.arch)(pretrained=False)

    device = torch.device("cuda:{}".format(LOCAL_RANK))
    model = model.to(device)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

    # We only save the model who uses device "cuda:0"
    # To resume, the device for the saved model would also be "cuda:0"
    if resume:
        map_location = {"cuda:0": "cuda:{}".format(LOCAL_RANK)}
        ddp_model.load_state_dict(torch.load(model_filepath, map_location=map_location))

    # The loaders exist only on the real-data path; the synthetic path reuses
    # one fixed batch that already lives on the device.
    train_loader = None
    test_loader = None
    if use_syn:
        # Synthetic data
        inputs_syn = torch.rand((batch_size, c, w, h)).to(device)
        labels_syn = torch.zeros(batch_size, dtype=torch.int64).to(device)
    else:
        # Prepare dataset and dataloader
        # NOTE(review): this augmenting transform is also applied to the test set below.
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        # Data should be prefetched
        # Download should be set to be False, because it is not multiprocess safe
        train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
        test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=transform)

        # Restricts data loading to a subset of the dataset exclusive to the current process
        train_sampler = DistributedSampler(dataset=train_set)

        train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler, num_workers=8)
        # Test loader does not have to follow distributed sampling strategy
        test_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=8)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # Loop over the dataset multiple times
    times = []
    for epoch in range(num_epochs):

        print("Local Rank: {}, Epoch: {}, Training ...".format(LOCAL_RANK, epoch))

        # Save and evaluate model routinely (real data only, every 10th epoch)
        # NOTE(review): only LOCAL_RANK 0 runs this forward pass through the
        # DDP wrapper; with several processes per node consider evaluating
        # ddp_model.module instead so ranks cannot desynchronize -- TODO confirm.
        if not use_syn and epoch % 10 == 0 and LOCAL_RANK == 0:
            accuracy = evaluate(model=ddp_model, device=device, test_loader=test_loader)
            torch.save(ddp_model.state_dict(), model_filepath)
            print("-" * 75)
            print("Epoch: {}, Accuracy: {}".format(epoch, accuracy))
            print("-" * 75)

        ddp_model.train()
        if not use_syn:
            # Only the real-data path has a sampler; calling this on the
            # synthetic path would fail because no loader exists there.
            train_loader.sampler.set_epoch(epoch)

        start_epoch = time.time()
        if use_syn:
            for _ in range(num_steps_syn):
                optimizer.zero_grad()
                outputs = ddp_model(inputs_syn)
                loss = criterion(outputs, labels_syn)
                loss.backward()
                optimizer.step()
            # Report the number of steps actually executed, not the last loop index.
            count = num_steps_syn
        else:
            count = 0
            for data in train_loader:
                inputs, labels = data[0].to(device), data[1].to(device)
                optimizer.zero_grad()
                outputs = ddp_model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                count += 1
        # Wait for queued GPU work so the wall-clock time covers the whole epoch.
        torch.cuda.synchronize()
        elapsed = time.time() - start_epoch

        # Epoch 0 is excluded from the timing statistics.
        if epoch > 0:
            times.append(elapsed)
            # max(..., 1) avoids a division by zero when no step was executed.
            print('num_steps_per_gpu: {}, avg_step_time: {:.4f}'.format(count, elapsed / max(count, 1)))

    # Average over the epochs that were actually recorded; with fewer than two
    # epochs nothing is recorded and there is no average to report.
    if times:
        avg_time = sum(times) / len(times)
        print("Average epoch time: {}".format(avg_time))
    else:
        print("Average epoch time: n/a (no timed epochs; epoch 0 is excluded)")

# Script entry point: run the training only when executed directly
# (e.g. through torchrun), not when the module is imported.
if __name__ == "__main__":
    
    main()

I want to run this on 3 nodes. The command I use to run this on each of the nodes is this:
node1:

torchrun --nproc_per_node=1 --nnodes=3 --node_rank=0 --master_addr=192.168.1.92 --master_port=8000 main.py --backend=gloo --batch_size=64 --arch=resnet152

node2:

torchrun --nproc_per_node=1 --nnodes=3 --node_rank=1 --master_addr=192.168.1.92 --master_port=8000 main.py --backend=gloo --batch_size=64 --arch=resnet152

node3:

torchrun --nproc_per_node=1 --nnodes=3 --node_rank=2 --master_addr=192.168.1.92 --master_port=8000 main.py --backend=gloo --batch_size=64 --arch=resnet152

As you can see, I’m trying the gloo backend, but it hangs at the line:

torch.distributed.init_process_group(backend=backend, rank=WORLD_RANK, world_size=WORLD_SIZE)

and these are the messages I see on node 1 when I set the debug flags:

[I debug.cpp:49] [c10d] The debug level is set to DETAIL.
[I socket.cpp:442] [c10d - debug] The server socket will attempt to listen on an IPv6 address.
[I socket.cpp:492] [c10d - debug] The server socket is attempting to listen on [::]:8000.
[I socket.cpp:566] [c10d] The server socket has started to listen on [::]:8000.
[I socket.cpp:624] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (192.168.1.92, 8000).
[I socket.cpp:699] [c10d - trace] The client socket is attempting to connect to [mortezauoc2]:8000.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc2]:32920.
[I socket.cpp:787] [c10d] The client socket has connected to [mortezauoc2]:8000 on [mortezauoc2]:32920.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc3]:37382.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc3]:37384.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc1]:46816.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc1]:46818.
[I socket.cpp:624] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (192.168.1.92, 8000).
[I socket.cpp:699] [c10d - trace] The client socket is attempting to connect to [mortezauoc2]:8000.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc2]:32922.
[I socket.cpp:787] [c10d] The client socket has connected to [mortezauoc2]:8000 on [mortezauoc2]:32922.
[I debug.cpp:49] [c10d] The debug level is set to DETAIL.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc1]:46820.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc1]:46822.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc3]:37386.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc3]:37388.
gloo
[I socket.cpp:624] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (192.168.1.92, 8000).
[I socket.cpp:699] [c10d - trace] The client socket is attempting to connect to [mortezauoc2]:8000.
[I socket.cpp:787] [c10d] The client socket has connected to [mortezauoc2]:8000 on [mortezauoc2]:32924.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc2]:32924.
[I socket.cpp:624] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (192.168.1.92, 8000).
[I socket.cpp:699] [c10d - trace] The client socket is attempting to connect to [mortezauoc2]:8000.
[I socket.cpp:787] [c10d] The client socket has connected to [mortezauoc2]:8000 on [mortezauoc2]:32926.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc2]:32926.

I also tried with only 2 nodes; it sometimes works and sometimes doesn’t.
This was for the gloo backend.
Something similar also happens with the nccl backend, but there it hangs on the line

ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

and output of node1:

[I debug.cpp:49] [c10d] The debug level is set to DETAIL.
[I socket.cpp:442] [c10d - debug] The server socket will attempt to listen on an IPv6 address.
[I socket.cpp:492] [c10d - debug] The server socket is attempting to listen on [::]:8000.
[I socket.cpp:566] [c10d] The server socket has started to listen on [::]:8000.
[I socket.cpp:624] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (192.168.1.92, 8000).
[I socket.cpp:699] [c10d - trace] The client socket is attempting to connect to [mortezauoc2]:8000.
[I socket.cpp:787] [c10d] The client socket has connected to [mortezauoc2]:8000 on [mortezauoc2]:32930.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc2]:32930.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc3]:37394.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc3]:37396.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc1]:46824.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc1]:46826.
[I socket.cpp:624] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (192.168.1.92, 8000).
[I socket.cpp:699] [c10d - trace] The client socket is attempting to connect to [mortezauoc2]:8000.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc2]:32932.
[I socket.cpp:787] [c10d] The client socket has connected to [mortezauoc2]:8000 on [mortezauoc2]:32932.
[I debug.cpp:49] [c10d] The debug level is set to DETAIL.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc1]:46828.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc1]:46830.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc3]:37398.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc3]:37400.
nccl
[I socket.cpp:624] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (192.168.1.92, 8000).
[I socket.cpp:699] [c10d - trace] The client socket is attempting to connect to [mortezauoc2]:8000.
[I socket.cpp:787] [c10d] The client socket has connected to [mortezauoc2]:8000 on [mortezauoc2]:32934.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc2]:32934.
[I socket.cpp:624] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (192.168.1.92, 8000).
[I socket.cpp:699] [c10d - trace] The client socket is attempting to connect to [mortezauoc2]:8000.
[I socket.cpp:787] [c10d] The client socket has connected to [mortezauoc2]:8000 on [mortezauoc2]:32936.
[I socket.cpp:295] [c10d - debug] The server socket on [::]:8000 has accepted a connection from [mortezauoc2]:32936.
[I ProcessGroupNCCL.cpp:629] [Rank 0] NCCL_BLOCKING_WAIT and NCCL_ASYNC_ERROR_HANDLING|NCCL_DESYNC_DEBUGshould not both be enabled. Only NCCL_BLOCKING_WAIT is being used in this process.
[I ProcessGroupNCCL.cpp:665] [Rank 0] ProcessGroupNCCL initialized with following options:
NCCL_ASYNC_ERROR_HANDLING: 0
NCCL_DESYNC_DEBUG: 0
NCCL_BLOCKING_WAIT: 1
TIMEOUT(ms): 1800000
USE_HIGH_PRIORITY_STREAM: 0
[I ProcessGroupNCCL.cpp:842] [Rank 0] NCCL watchdog thread started!
HELLO
/home/morteza/torch/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/morteza/torch/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)
0
mortezauoc2:243289:243289 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to ens3
mortezauoc2:243289:243289 [0] NCCL INFO NCCL_SOCKET_IFNAME set to ens3
mortezauoc2:243289:243289 [0] NCCL INFO Bootstrap : Using ens3:192.168.1.92<0>
mortezauoc2:243289:243289 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
mortezauoc2:243289:243289 [0] NCCL INFO cudaDriverVersion 11070
NCCL version 2.14.3+cuda11.7
mortezauoc2:243289:243289 [0] NCCL INFO init.cc:1147 Cuda Host Alloc Size 4 pointer 0x7f83e2000000
mortezauoc2:243289:243312 [0] NCCL INFO Failed to open libibverbs.so[.1]
mortezauoc2:243289:243312 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to ens3
mortezauoc2:243289:243312 [0] NCCL INFO NET/Socket : Using [0]ens3:192.168.1.92<0>
mortezauoc2:243289:243312 [0] NCCL INFO Using network Socket

Each of these nodes is a VM. This may have something to do with the root cause of the problem, but I don’t know how to investigate that. This is the output of ifconfig:

$ ifconfig
ens3: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 1450
        inet 192.168.1.92  netmask 255.255.255.0  broadcast 192.168.1.255
        inet6 fe80::f816:3eff:fee7:862  prefixlen 64  scopeid 0x20<link>
        ether fa:16:3e:e7:08:62  txqueuelen 1000  (Ethernet)
        RX packets 18581866  bytes 509072745967 (509.0 GB)
        RX errors 0  dropped 0  overruns 0  frame 0
        TX packets 12684239  bytes 503527152043 (503.5 GB)
        TX errors 0  dropped 0 overruns 0  carrier 0  collisions 0

lo: flags=73<UP,LOOPBACK,RUNNING>  mtu 65536
        inet 127.0.0.1  netmask 255.0.0.0
        inet6 ::1  prefixlen 128  scopeid 0x10<host>
        loop  txqueuelen 1000  (Local Loopback)
        RX packets 1380710  bytes 80791900 (80.7 MB)
        RX errors 0  dropped 0  overruns 0  frame 0
        TX packets 1380710  bytes 80791900 (80.7 MB)
        TX errors 0  dropped 0 overruns 0  carrier 0  collisions 0

Do you know what could be causing the problem, or how I can find out?