Understanding the `torch.distributed.elastic.multiprocessing.errors.ChildFailedError` error

Hi,
I’m debugging a DDP script launched via `torchrun --nproc_per_node=2 train.py`. The model is wrapped in the following way:

from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(
    model,
    device_ids=[args.local_rank] if args.use_cuda else None,
    output_device=args.local_rank if args.use_cuda else None,
)

The code works on a single device.
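
For context, the process group is initialized roughly along these lines before the wrap (a minimal sketch, not the exact code from train.py; it assumes the backend is picked from args.use_cuda and that the local rank comes from torchrun's LOCAL_RANK environment variable):

import os

import torch
import torch.distributed as dist


def setup_distributed(args):
    # nccl for CUDA runs, gloo for CPU runs
    backend = "nccl" if args.use_cuda else "gloo"
    # torchrun exports RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and LOCAL_RANK,
    # so the default env:// init method can be used
    dist.init_process_group(backend=backend)
    args.local_rank = int(os.environ["LOCAL_RANK"])
    if args.use_cuda:
        torch.cuda.set_device(args.local_rank)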
Here is the log I obtained by setting TORCH_DISTRIBUTED_DEBUG=DETAIL:

Original exception was:
Traceback (most recent call last):
  File "/home/user/train.py", line 466, in <module>
    main()
  File "/home/user/train.py", line 207, in main
    train_losses = loader_forward(
  File "/home/user/train.py", line 361, in loader_forward
    seq_losses = batched_seq_forward(
  File "/home/user/train.py", line 443, in batched_seq_forward
    loss.backward()
  File "/opt/miniconda3/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/miniconda3/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: ProcessGroupWrapper: Monitored Barrier encountered error running collective: CollectiveFingerPrint(SequenceNumber=6, OpType=ALLREDUCE, TensorShape=[13209847], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))). Error: 
[../third_party/gloo/gloo/transport/tcp/pair.cc:534] Connection closed by peer [127.0.0.1]:36491
[2024-11-10 23:15:10,476] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 2671122) of binary: /opt/miniconda3/bin/python
Traceback (most recent call last):
  File "/home/user/.local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/miniconda3/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/opt/miniconda3/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/opt/miniconda3/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/miniconda3/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-11-10_23:15:10
  host      : coeus
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 2671123)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-11-10_23:15:10
  host      : coeus
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 2671122)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================ 

The script works when using CPU and the gloo backend (rather than CUDA with nccl).
The issue originates in the loss.backward() call, but the error message is not clear to me. Could you help me understand the problem?

The underlying problem turned out to be a bug in the model itself. Two things were useful for tracking it down:

  • setting two environment variables when calling torchrun: CUDA_LAUNCH_BLOCKING=1 TORCH_DISTRIBUTED_DEBUG=DETAIL
  • decorating main() with record (from torch.distributed.elastic.multiprocessing.errors import record), as pointed out in Error Propagation — PyTorch 2.5 documentation; see the sketch after this list
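
A minimal sketch of the second point, assuming main() is the function that train.py runs when invoked by torchrun:

# Propagate the worker's traceback into the torchrun failure summary.
from torch.distributed.elastic.multiprocessing.errors import record


@record  # writes the exception to the per-worker error file that torchrun reports
def main():
    ...  # set up DDP, run the training loop, call loss.backward(), etc.


if __name__ == "__main__":
    main()

With the decorator in place, the entries shown as error_file: <N/A> in the failure summary above should instead point to a file containing the actual worker traceback, which makes the root cause much easier to spot.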