Failure when initializing PyTorch DDP-style code (multi-machine, multi-card environment)

I implemented a DDP-style PyTorch program that runs successfully in a single-machine multi-card environment.
The program (demo3_ddp.py) is as follows:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import os.path as osp
import socket
from contextlib import closing

from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.distributed as dist
#from apex.parallel import DistributedDataParallel as DDP
#from apex import amp

from DDP.model import ConvNet


def train(gpu, args):
    ############################################################
    rank = args.node_rank * args.gpus_per_node + gpu

    assert torch.distributed.is_nccl_available()
    if args.init_method == 'ENV':
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=args.world_size,
            rank=rank
        )
    elif args.init_method == 'TCP':
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://{}:{}'.format(args.master_addr, args.master_port),
            world_size=args.world_size,
            rank=rank
        )
    elif args.init_method == 'SFILE':
        dist.init_process_group(
            backend='nccl',
            init_method='file://{}'.format(args.shard_file),
            world_size=args.world_size,
            rank=rank
            )
    else:
        print(f'No implementation for init_method={args.init_method}')
        exit(-1)

    print(f'process/rank {rank} is initialized.')
    ############################################################

    torch.manual_seed(0)
    model = ConvNet()
    '''
    torch.cuda.set_device(gpu_id)  # set single gpu for current process
    torch.cuda.set_device('cuda:' + str(gpu_ids))  # set multiple gpus for current process
    '''
    torch.cuda.set_device(gpu)  # gpu is the local rank; pin this process to its GPU on the current node
    model.cuda(gpu)
    ###############################################################
    # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # especially for BN operation in DDP mode
    # Wrap the model
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu], output_device=gpu)
    ###############################################################

    '''
    if rank == 0:
        torch.save(model.state_dict(), osp.join(args.checkpoint_path, 'model.pt'))

    dist.barrier() # ensure the save on rank 0 has finished before other ranks load
    map_location = {"cuda:0": f"cuda:{gpu}"}
    model.load_state_dict(torch.load(osp.join(args.checkpoint_path, 'model.pt'), map_location=map_location)) # each process loads the model saved by rank 0
    '''

    batch_size = args.batch_size
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    # Data loading code
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    train_dataset = torchvision.datasets.MNIST(
        root='./data',
        train=True,
        transform=trans,
        download=True
    )
    ################################################################
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=args.world_size,
        rank=rank
    )
    ################################################################

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        ##############################
        shuffle=False,  #
        ##############################
        num_workers=0,
        pin_memory=True,
        #############################
        sampler=train_sampler)  #
    #############################

    os.makedirs(args.checkpoint_path, exist_ok=True)

    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(args.epochs):
        # set the sampler's epoch so that each process sees a different data ordering across epochs
        train_sampler.set_epoch(epoch)
        for i, (images, labels) in enumerate(train_loader):
            print(f'Batch, input.shape: {images.shape}')
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #if (i + 1) % 100 == 0 and gpu == 0:
            if (i + 1) % 100 == 0 and rank == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1,
                    args.epochs,
                    i + 1,
                    total_step,
                    loss.item())
                   )

        # save model per epoch
        if rank == 0:
            # note that dist.barrier() is not required here, since the all-reduce in backward already keeps the processes synchronized.
            torch.save(model.state_dict(), osp.join(args.checkpoint_path, f'model_{epoch}.pt'))
    #if gpu == 0:
    #rank = dist.get_rank()
    if rank == 0:
        # only print message in the first process
        print("Training complete in: " + str(datetime.now() - start))


def evaluation(gpu, args):
    '''
    Not implemented, consider......
    '''
    ############################################################
    rank = args.node_rank * args.gpus_per_node + gpu

    if args.init_method == 'ENV':
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=args.world_size,
            rank=rank
        )
    elif args.init_method == 'TCP':
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://172.16.21.159:8887',
            world_size=args.world_size,
            rank=rank
        )
    print(f'process/rank {rank} is initialized.')
    ############################################################

    torch.manual_seed(0)
    model = ConvNet()
    '''
    torch.cuda.set_device(gpu_id)  # set single gpu for current process
    torch.cuda.set_device('cuda:' + str(gpu_ids))  # set multiple gpus for current process
    '''
    torch.cuda.set_device(gpu)  # gpu is the local rank; pin this process to its GPU on the current node
    model.cuda(gpu)
    ###############################################################
    # Wrap the model
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    ###############################################################

    '''
    if rank == 0:
        torch.save(model.state_dict(), osp.join(args.checkpoint_path, 'model.pt'))

    dist.barrier() # ensure the save on rank 0 has finished before other ranks load
    map_location = {"cuda:0": f"cuda:{gpu}"}
    model.load_state_dict(torch.load(osp.join(args.checkpoint_path, 'model.pt'), map_location=map_location)) # each process loads the model saved by rank 0
    '''

    batch_size = args.batch_size
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    # Data loading code
    train_dataset = torchvision.datasets.MNIST(
        root='./data',
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )
    ################################################################
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=args.world_size,
        rank=rank
    )
    ################################################################

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        ##############################
        shuffle=False,  #
        ##############################
        num_workers=0,
        pin_memory=True,
        #############################
        sampler=train_sampler)  #
    #############################

    os.makedirs(args.checkpoint_path, exist_ok=True)

    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(args.epochs):
        # set the sampler's epoch so that each process sees a different data ordering across epochs
        train_sampler.set_epoch(epoch)
        for i, (images, labels) in enumerate(train_loader):
            print(f'Batch, input.shape: {images.shape}')
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #if (i + 1) % 100 == 0 and gpu == 0:
            if (i + 1) % 100 == 0 and rank == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1,
                    args.epochs,
                    i + 1,
                    total_step,
                    loss.item())
                   )
        # save model per epoch
        if rank == 0:
            # note that dist.barrier() is not required here, since the all-reduce in backward already keeps the processes synchronized.
            torch.save(model.state_dict(), osp.join(args.checkpoint_path, f'model_{epoch}.pt'))
    #if gpu == 0:
    #rank = dist.get_rank()
    if rank == 0:
        # only print message in the first process
        print("Training complete in: " + str(datetime.now() - start))


def get_open_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--init_method', type=str, default='ENV', choices=['TCP', 'ENV'])

    parser.add_argument('-nodes', '--nnodes', default=1, type=int, metavar='N', help='the number of nodes')
    parser.add_argument('-gpus', '--gpus_per_node', default=1, type=int, help='number of gpus per node')
    parser.add_argument('-nr', '--node_rank', default=0, type=int, help='rank of the current node')

    parser.add_argument('--master_addr', type=str, default='', help='master addr')
    parser.add_argument('--master_port', type=str, default='36185', help='')

    parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--batch_size', default=100, type=int, help='')

    parser.add_argument('--checkpoint_path', type=str, default='./DDP/checkpoints/model.checkpoints', help='')
    args = parser.parse_args()


    free_port = get_open_port()  # note: this port is only printed for reference; MASTER_PORT below comes from args.master_port
    print(f'free_port: {free_port}')


    #########################################################
    args.world_size = args.gpus_per_node * args.nnodes                #  #GPU = #Process
    os.environ['MASTER_ADDR'] = args.master_addr             #  IP addr of process-0 (node-0)
    os.environ['MASTER_PORT'] = args.master_port             #  PORT of process-0 (node-0)

    '''
    using mp.spawn
     mp.spawn(train, nprocs=args.gpus_per_node, args=(args,)) => [train(i, args) for i in [0, args.gpus_per_node - 1]] 
     '''
    mp.spawn(train, nprocs=args.gpus_per_node, args=(args,))         # start all processes in the current node using torch.mp
    #########################################################

    '''
    # using mp.Process
    processes = []
    for local_rank in range(args.gpus_per_node):
        p = mp.Process(target=train, args=(local_rank, args))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
    '''


if __name__ == '__main__':
    print(f'torch.cuda.is_available(): {torch.cuda.is_available()}')
    print(f'torch.cuda.device_count(): {torch.cuda.device_count()}')
    main()

The ConvNet model (model.py) is implemented as follows:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torchvision


class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        print(f'Forward, input.shape: {x.shape}')
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out


class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        net = torchvision.models.resnet101(num_classes=num_classes)
        net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
        self.network = net

    def forward(self, x):
        return self.network(x)
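
For reference, here is a quick shape check of the ConvNet above (just a sketch, assuming the standard 1x28x28 MNIST inputs used in the training code): the two stride-2 max-pool layers reduce 28 -> 14 -> 7, so the flattened features have size 7*7*32, matching the final linear layer.

#!/usr/bin/env python
# Sanity-check sketch (not part of the training script): verify ConvNet's shapes on a dummy batch.
import torch
from DDP.model import ConvNet

x = torch.randn(4, 1, 28, 28)      # dummy batch of 4 MNIST-sized images
model = ConvNet(num_classes=10)
out = model(x)                     # also prints "Forward, input.shape: torch.Size([4, 1, 28, 28])"
print(out.shape)                   # expected: torch.Size([4, 10])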

My experimental environment is:
machine 1: two RTX 3090 GPUs;
machine 2: two RTX 4090 GPUs.

The program is launched with the following commands on the two machines, respectively (the resulting rank layout is sketched after the commands):

    $ python -m DDP.demo3_ddp -nodes 2 -gpus 2 -nr 0  --master_addr 'your_master_machine_ip'
    $ python -m DDP.demo3_ddp -nodes 2 -gpus 2 -nr 1  --master_addr 'your_master_machine_ip'
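
For clarity, here is a minimal sketch (reusing the exact formulas from demo3_ddp.py, no new logic) of how these two commands map to global ranks and the world size:

# Sketch only: reproduces the rank bookkeeping of demo3_ddp.py for nnodes=2, gpus_per_node=2.
nnodes, gpus_per_node = 2, 2
world_size = gpus_per_node * nnodes                 # 4 processes in total
for node_rank in range(nnodes):                     # node 0 = master machine, node 1 = second machine
    for gpu in range(gpus_per_node):                # local rank within a node (one process per GPU)
        rank = node_rank * gpus_per_node + gpu
        print(f'node {node_rank}, local gpu {gpu} -> global rank {rank} of {world_size}')
# node 0 contributes ranks 0 and 1; node 1 contributes ranks 2 and 3

Since neither command passes --master_port, all four processes should rendezvous at 'your_master_machine_ip' on the default port 36185.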

The results are as follows. On the first (master) machine, the output is:

torch.cuda.is_available(): True
torch.cuda.device_count(): 2
free_port: 37781
process/rank 0 is initialized.
process/rank 1 is initialized.

On the second machine, the output is:

torch.cuda.is_available(): True
torch.cuda.device_count(): 2
free_port: 56899

Process SpawnProcess-1:
Process SpawnProcess-2:

torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/tme/anaconda3/envs/pt1.10py3.9/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/qinkeke/workspace/pycharm/DDP_project/DDP/demo3_ddp.py", line 70, in train
    dist.init_process_group(
  File "/home/tme/anaconda3/envs/pt1.10py3.9/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 611, in init_process_group
    default_pg._set_sequence_number_for_group()
RuntimeError: Socket Timeout

It seems that the process group cannot be initialized successfully. Can anyone help me fix this issue?