Scatter operation does not work when there is more than one node

I am trying to run the sample below: it works when run on 1 node, but crashes when run on 2 nodes.
Example from: Distributed communication package - torch.distributed — PyTorch 2.2 documentation
torch.distributed.broadcast() works correctly with 2 nodes.
Can you please tell me where I went wrong?
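
For reference, the broadcast case I am comparing against is roughly this (minimal sketch; launched with the same torchrun commands as below):

import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
if dist.get_rank() == 0:
    t = torch.ones(2).cuda(local_rank)
else:
    t = torch.zeros(2).cuda(local_rank)
dist.broadcast(t, src=0)  # after this, rank 1 also holds ones
print(dist.get_rank(), t)
dist.destroy_process_group()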

This works:
NCCL_IB_TIMEOUT=20 NCCL_DEBUG=INFO torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 --rdzv_id=456 --rdzv_backend=c10d --rdzv_endpoint=server_0.local:12345 elastic_ddp.py

This does not work:
master-node:
NCCL_IB_TIMEOUT=20 NCCL_DEBUG=INFO torchrun --nproc_per_node=1 --nnodes=2 --node_rank=0 --rdzv_id=456 --rdzv_backend=c10d --rdzv_endpoint=server_0.local:12345 elastic_ddp.py

second-node:
NCCL_IB_TIMEOUT=20 NCCL_DEBUG=INFO torchrun --nproc_per_node=1 --nnodes=2 --node_rank=1 --rdzv_id=456 --rdzv_backend=c10d --rdzv_endpoint=server_0.local:12345 elastic_ddp.py

Code:

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import time
import os

from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from datetime import timedelta

class TensorDataset(Dataset):
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "All tensors must have the same first dimension"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1024)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1024, 5)
        self.il = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.net2(self.il(self.il(self.il(self.il(self.relu(self.net1(x)))))))


def demo_basic(rank, batch):
    startTime = time.monotonic()

    device_id = rank % torch.cuda.device_count()
    model = ToyModel().to(device_id)
    ddp_model = DDP(model, device_ids=[device_id])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    inputs, labels = batch
    inputs = inputs.to(device_id)
    labels = labels.to(device_id)

    optimizer.zero_grad()
    outputs = ddp_model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()
    return time.monotonic() - startTime

if __name__ == "__main__":

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    local_rank = int(os.environ['LOCAL_RANK'])
    tensor_size = 2
    t_ones = torch.ones(tensor_size).cuda(local_rank)
    t_fives = torch.ones(tensor_size).cuda(local_rank) * 5
    output_tensor = torch.zeros(tensor_size).cuda(local_rank)
    if dist.get_rank() == 0:
        # Assumes world_size of 2.
        # Only tensors, all of which must be the same size.
        scatter_list = [t_ones, t_fives]
    else:
        scatter_list = None
    dist.scatter(output_tensor, scatter_list, src=0, async_op=True)
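    # Note: with async_op=True this returns a Work handle; output_tensor is
    # only guaranteed to be populated once wait() is called on that handle.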
    
    full_batch_size = 90000
    a = torch.tensor(([i for i in range(full_batch_size) for _ in range(10)]), dtype=torch.float32).view(full_batch_size, 10)
    b = torch.zeros((full_batch_size, 5), dtype=torch.float32)
    
    dataset = TensorDataset(a, b)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    dataloader = DataLoader(dataset, batch_size=3000, sampler=sampler)

    common_time = []
    for batch in dataloader:
        common_time.append(demo_basic(rank, batch))
        time.sleep(0.01)
    print(sum(common_time) / len(common_time))
    dist.destroy_process_group()
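
Stripped of the DDP training part, the scatter step alone reduces to the sketch below (same call pattern and launch commands as above; the only addition is a wait() on the async handle before printing):

import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
tensor_size = 2
output_tensor = torch.zeros(tensor_size).cuda(local_rank)
if dist.get_rank() == 0:
    # Assumes a world size of 2, as in the full script.
    scatter_list = [torch.ones(tensor_size).cuda(local_rank),
                    torch.ones(tensor_size).cuda(local_rank) * 5]
else:
    scatter_list = None
work = dist.scatter(output_tensor, scatter_list, src=0, async_op=True)
work.wait()
print(dist.get_rank(), output_tensor)
dist.destroy_process_group()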

Log from the master node:

root@server_0:/prj/test/ddp_1# NCCL_IB_TIMEOUT=20 NCCL_DEBUG=INFO torchrun --nproc_per_node=1 --nnodes=2 --node_rank=0 --rdzv_id=456 --rdzv_backend=c10d --rdzv_endpoint=server_0.local:12345 elastic_ddp.py
[2024-04-19 13:53:30,495] torch.distributed.run: [WARNING] master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
server_0:27160:27160 [0] NCCL INFO Bootstrap : Using ens3:10.1.1.40<0>
server_0:27160:27160 [0] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
server_0:27160:27160 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
server_0:27160:27160 [0] NCCL INFO cudaDriverVersion 12040
NCCL version 2.18.6+cuda12.1
server_0:27160:27177 [0] NCCL INFO Failed to open libibverbs.so[.1]
server_0:27160:27177 [0] NCCL INFO NET/Socket : Using [0]ens3:10.1.1.40<0>
server_0:27160:27177 [0] NCCL INFO Using network Socket
server_0:27160:27177 [0] NCCL INFO comm 0x8e59270 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 60 commId 0x6346658378e969cf - Init START
server_0:27160:27177 [0] NCCL INFO Channel 00/02 :    0   1
server_0:27160:27177 [0] NCCL INFO Channel 01/02 :    0   1
server_0:27160:27177 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] -1/-1/-1->0->1
server_0:27160:27177 [0] NCCL INFO P2P Chunksize set to 131072
server_0:27160:27177 [0] NCCL INFO Channel 00/0 : 1[0] -> 0[0] [receive] via NET/Socket/0
server_0:27160:27177 [0] NCCL INFO Channel 01/0 : 1[0] -> 0[0] [receive] via NET/Socket/0
server_0:27160:27177 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[0] [send] via NET/Socket/0
server_0:27160:27177 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[0] [send] via NET/Socket/0
server_0:27160:27177 [0] NCCL INFO Connected all rings
server_0:27160:27177 [0] NCCL INFO Connected all trees
server_0:27160:27177 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
server_0:27160:27177 [0] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
server_0:27160:27177 [0] NCCL INFO comm 0x8e59270 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 60 commId 0x6346658378e969cf - Init COMPLETE
server_0:27160:27180 [0] NCCL INFO Channel 00/1 : 0[0] -> 1[0] [send] via NET/Socket/0/Shared
server_0:27160:27180 [0] NCCL INFO Channel 01/1 : 0[0] -> 1[0] [send] via NET/Socket/0/Shared

server_0:27160:27178 [0] misc/socket.cc:483 NCCL WARN socketStartConnect: Connect to fe80::3cb5:9cff:fed8:77d3%7<41379> failed : Network is unreachable
server_0:27160:27178 [0] NCCL INFO misc/socket.cc:564 -> 2
server_0:27160:27178 [0] NCCL INFO misc/socket.cc:618 -> 2
server_0:27160:27178 [0] NCCL INFO transport/net_socket.cc:333 -> 2
server_0:27160:27178 [0] NCCL INFO transport/net.cc:592 -> 2
server_0:27160:27178 [0] NCCL INFO proxy.cc:1306 -> 2
server_0:27160:27178 [0] NCCL INFO proxy.cc:1377 -> 2

server_0:27160:27178 [0] proxy.cc:1519 NCCL WARN [Proxy Service 0] Failed to execute operation Connect from rank 0, retcode 2

server_0:27160:27180 [0] misc/socket.cc:49 NCCL WARN socketProgress: Connection closed by remote peer server_0.mcs.local<48211>
server_0:27160:27180 [0] NCCL INFO misc/socket.cc:749 -> 6

server_0:27160:27180 [0] proxy.cc:1143 NCCL WARN Socket recv failed while polling for opId=0x7fa1bf83baf8
server_0:27160:27180 [0] NCCL INFO transport/net.cc:288 -> 3
server_0:27160:27180 [0] NCCL INFO transport.cc:148 -> 3
server_0:27160:27180 [0] NCCL INFO group.cc:111 -> 3
server_0:27160:27180 [0] NCCL INFO group.cc:65 -> 3 [Async thread]
server_0:27160:27160 [0] NCCL INFO group.cc:406 -> 3
server_0:27160:27160 [0] NCCL INFO group.cc:96 -> 3
Traceback (most recent call last):
  File "/prj/test/ddp_1/elastic_ddp.py", line 86, in <module>
    dist.scatter(output_tensor, scatter_list, src=0, async_op=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3174, in scatter
    work = default_pg.scatter(output_tensors, input_tensors, opts)
RuntimeError: NCCL Error 3: internal error - please report this issue to the NCCL developers
server_0:27160:27160 [0] NCCL INFO comm 0x8e59270 rank 0 nranks 2 cudaDev 0 busId 60 - Abort COMPLETE
[2024-04-19 13:53:38,727] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 27160) of binary: /opt/conda/bin/python3
Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
elastic_ddp.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-04-19_13:53:38
  host      : server_0.mcs.local
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 27160)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================