`RuntimeError: Detected mismatch between collectives on ranks` (SequenceNumber mismatch) in multi-GPU DDP training

Hello! I’m having an issue where, during DistributedDataParallel (DDP) synchronization, I receive a `RuntimeError: Detected mismatch between collectives on ranks`, with `Collectives differ in the following aspects: Sequence number: 6 vs 66`.

I am able to reproduce this minimally by taking the example code from the basic torchrun section of the DDP tutorial (Getting Started with Distributed Data Parallel — PyTorch Tutorials 2.2.0+cu121 documentation), modifying it to use the Gloo backend, and performing two optimization steps, each followed by a torch.distributed.monitored_barrier. The complete code is below.

This occurs when using distributed training with multiple GPUs. I’ve tried it on two distributed systems with significantly different server setups, one using SLURM and one using PBS, so I don’t believe it is related to the clusters being used.

I’m able to find very little information on what a SequenceNumber mismatch entails.
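
From the DETAIL-mode log further down, my rough mental model (which may well be wrong, hence this post) is that the ProcessGroupWrapper keeps a per-process-group count of the collectives each rank has issued, and before every collective it compares a fingerprint of (sequence number, op type, tensor shapes, dtypes) across ranks. A toy, single-process sketch of that idea, using hypothetical names rather than the real ProcessGroupWrapper API:

from dataclasses import dataclass

@dataclass(frozen=True)
class Fingerprint:
    # Mirrors the fields visible in the CollectiveFingerPrint log lines below.
    sequence_number: int
    op_type: str
    shape: tuple
    dtype: str

class ToyWrapper:
    # Stand-in for how I imagine the wrapper counts collectives per process group.
    def __init__(self):
        self.seq = 0

    def next_fingerprint(self, op_type, shape, dtype):
        fp = Fingerprint(self.seq, op_type, shape, dtype)
        self.seq += 1
        return fp

# Two "ranks" that have issued different numbers of collectives so far.
rank0, rank1 = ToyWrapper(), ToyWrapper()
for _ in range(10):
    rank0.next_fingerprint("ALLREDUCE", (165,), "Float")
for _ in range(6):
    rank1.next_fingerprint("ALLREDUCE", (165,), "Float")

# At the next matched collective the fingerprints disagree only in the sequence
# number, which is what "Sequence number: 10vs 6" appears to report.
fp0 = rank0.next_fingerprint("BROADCAST", (5,), "Int")
fp1 = rank1.next_fingerprint("BROADCAST", (5,), "Int")
assert (fp0.sequence_number, fp1.sequence_number) == (10, 6)
assert fp0 != fp1

If that reading is roughly right, my real question becomes why rank 0 has issued four more collectives than ranks 1–3 by the time the failing BROADCAST in the second forward pass runs.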

Any suggestions would be greatly appreciated. Thank you for your time!

Python script

import datetime

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic():
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {rank}.")

    device_id = rank % torch.cuda.device_count()
    model = ToyModel().to(device_id)
    ddp_model = DDP(model, device_ids=[device_id])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    torch.distributed.monitored_barrier(timeout=datetime.timedelta(minutes=5))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    torch.distributed.monitored_barrier(timeout=datetime.timedelta(minutes=5))

    dist.destroy_process_group()

if __name__ == "__main__":
    demo_basic()

SLURM shell script

#!/bin/bash

#SBATCH --job-name="¯\\_(ツ)_/¯"
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=2
#SBATCH --cpus-per-task=40
#SBATCH --mem=600000
#SBATCH --time=5-00:00:00

nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

export LOGLEVEL=INFO
export TORCH_CPP_LOG_LEVEL=INFO
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export CUDA_LAUNCH_BLOCKING=1

srun torchrun \
--nnodes 1 \
--nproc_per_node 2 \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:36484 \
scripts/quick.py
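
I have only launched this through SLURM and PBS, but for anyone trying to reproduce without a scheduler, the equivalent on a single node with two GPUs should just be a standalone torchrun with the same debug variables set (untested sketch):

export TORCH_CPP_LOG_LEVEL=INFO
export TORCH_DISTRIBUTED_DEBUG=DETAIL
torchrun --standalone --nnodes 1 --nproc_per_node 2 scripts/quick.py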

Truncated output log

[I debug.cpp:49] [c10d] The debug level is set to DETAIL.
[2024-02-05 19:38:41,000] torch.distributed.run: [INFO] Using nproc_per_node=gpu, setting to 4 since the instance has 40 gpu
[2024-02-05 19:38:41,000] torch.distributed.run: [WARNING] 
[2024-02-05 19:38:41,000] torch.distributed.run: [WARNING] *****************************************
[2024-02-05 19:38:41,000] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-02-05 19:38:41,000] torch.distributed.run: [WARNING] *****************************************
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO] Starting elastic_operator with launch configs:
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   entrypoint       : scripts/quick.py
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   min_nodes        : 1
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   max_nodes        : 1
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   nproc_per_node   : 4
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   run_id           : 23015
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   rdzv_backend     : c10d
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   rdzv_endpoint    : 10.100.172.7:36484
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   rdzv_configs     : {'timeout': 900}
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   max_restarts     : 0
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   monitor_interval : 5
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   log_dir          : None
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO]   metrics_cfg      : {}
[2024-02-05 19:38:41,001] torch.distributed.launcher.api: [INFO] 
[I socket.cpp:480] [c10d - debug] The server socket will attempt to listen on an IPv6 address.
[I socket.cpp:531] [c10d - debug] The server socket is attempting to listen on [::]:36484.
[I socket.cpp:605] [c10d] The server socket has started to listen on [::]:36484.
[I TCPStore.cpp:305] [c10d - debug] The server has started on port = 36484.
[I socket.cpp:720] [c10d - debug] The client socket will attempt to connect to an IPv6 address of (10.100.172.7, 36484).
[I socket.cpp:796] [c10d - trace] The client socket is attempting to connect to [gpu007.atstor.adapt.nccs.nasa.gov]:36484.
[I socket.cpp:299] [c10d - debug] The server socket on [::]:36484 has accepted a connection from [gpu007.atstor.adapt.nccs.nasa.gov]:50424.
[I socket.cpp:884] [c10d] The client socket has connected to [gpu007.atstor.adapt.nccs.nasa.gov]:36484 on [gpu007.atstor.adapt.nccs.nasa.gov]:50424.
[I TCPStore.cpp:342] [c10d - debug] TCP client connected to host 10.100.172.7:36484

...

[rank1]:[I ProcessGroupWrapper.cpp:570] [Rank 1] Running collective: CollectiveFingerPrint(SequenceNumber=3, OpType=ALLREDUCE, TensorShape=[165], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank0]:[I ProcessGroupWrapper.cpp:570] [Rank 0] Running collective: CollectiveFingerPrint(SequenceNumber=3, OpType=ALLREDUCE, TensorShape=[165], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank2]:[I ProcessGroupWrapper.cpp:570] [Rank 2] Running collective: CollectiveFingerPrint(SequenceNumber=3, OpType=ALLREDUCE, TensorShape=[165], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank3]:[I ProcessGroupWrapper.cpp:570] [Rank 3] Running collective: CollectiveFingerPrint(SequenceNumber=3, OpType=ALLREDUCE, TensorShape=[165], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank1]:[I logger.cpp:392] [Rank 1 / 4] [before iteration 1] Training ToyModel unused_parameter_size=0 
 Avg forward compute time: 141051904 
 Avg backward compute time: 289433600 
Avg backward comm. time: 2917376 
 Avg backward comm/comp overlap time: 2158592
[rank1]:[I reducer.cpp:1814] 1 buckets rebuilt with size limits: 1048576 bytes.
[rank3]:[I logger.cpp:392] [Rank 3 / 4] [before iteration 1] Training ToyModel unused_parameter_size=0 
 Avg forward compute time: 141513728 
 Avg backward compute time: 289193984 
Avg backward comm. time: 2116608 
 Avg backward comm/comp overlap time: 1277952
[rank2]:[I logger.cpp:392] [Rank 2 / 4] [before iteration 1] Training ToyModel unused_parameter_size=0 
 Avg forward compute time: 134927360 
 Avg backward compute time: 289390592 
Avg backward comm. time: 2549760 
 Avg backward comm/comp overlap time: 1744896
[rank3]:[I reducer.cpp:1814] 1 buckets rebuilt with size limits: 1048576 bytes.
[rank2]:[I reducer.cpp:1814] 1 buckets rebuilt with size limits: 1048576 bytes.
[rank0]:[I logger.cpp:392] [Rank 0 / 4] [before iteration 1] Training ToyModel unused_parameter_size=0 
 Avg forward compute time: 141504512 
 Avg backward compute time: 289249280 
Avg backward comm. time: 2713600 
 Avg backward comm/comp overlap time: 1920000
[rank0]:[I reducer.cpp:1814] 1 buckets rebuilt with size limits: 1048576 bytes.
[rank3]:[I ProcessGroupWrapper.cpp:570] [Rank 3] Running collective: CollectiveFingerPrint(SequenceNumber=6, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank0]:[I ProcessGroupWrapper.cpp:570] [Rank 0] Running collective: CollectiveFingerPrint(SequenceNumber=10, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank2]:[I ProcessGroupWrapper.cpp:570] [Rank 2] Running collective: CollectiveFingerPrint(SequenceNumber=6, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank1]:[I ProcessGroupWrapper.cpp:570] [Rank 1] Running collective: CollectiveFingerPrint(SequenceNumber=6, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
All four ranks fail with the same stack trace, each raising its own RuntimeError:

Traceback (most recent call last):
  File "/panfs/ccds02/nobackup/people/golmsche/ml4a/scripts/quick.py", line 50, in <module>
    demo_basic()
  File "/panfs/ccds02/nobackup/people/golmsche/ml4a/scripts/quick.py", line 41, in demo_basic
    outputs = ddp_model(torch.randn(20, 10))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    inputs, kwargs = self._pre_forward(*inputs, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1413, in _pre_forward
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: Detected mismatch between collectives on ranks. Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=10, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 1 is running collective: CollectiveFingerPrint(SequenceNumber=6, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))). Collectives differ in the following aspects:  Sequence number: 10vs 6
RuntimeError: Detected mismatch between collectives on ranks. Rank 1 is running collective: CollectiveFingerPrint(SequenceNumber=6, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=10, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))). Collectives differ in the following aspects:  Sequence number: 6vs 10
RuntimeError: Detected mismatch between collectives on ranks. Rank 2 is running collective: CollectiveFingerPrint(SequenceNumber=6, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=10, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))). Collectives differ in the following aspects:  Sequence number: 6vs 10
RuntimeError: Detected mismatch between collectives on ranks. Rank 3 is running collective: CollectiveFingerPrint(SequenceNumber=6, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=10, OpType=BROADCAST, TensorShape=[5], TensorDtypes=Int, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))). Collectives differ in the following aspects:  Sequence number: 6vs 10
[2024-02-05 19:38:51,274] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1008931) of binary: /explore/nobackup/people/golmsche/.conda/envs/haplo/bin/python3.11
[2024-02-05 19:38:51,284] torch.distributed.elastic.multiprocessing.errors: [INFO] ('local_rank %s FAILED with no error file. Decorate your entrypoint fn with @record for traceback info. See: https://pytorch.org/docs/stable/elastic/errors.html', 0)
Traceback (most recent call last):
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/explore/nobackup/people/golmsche/.conda/envs/haplo/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: