NCCL Timeout when one node fails - Torchrun Elastic

Hi, I'm running two nodes (on the same machine), each with its own assigned GPU. I'm testing scale-down: when I kill one node with a SIGTERM signal, the other node hangs. After about ten minutes (the 600000 ms timeout shown in the log), the NCCL watchdog throws the following error:

[rank0]:[E ProcessGroupNCCL.cpp:563] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3917, OpType=ALLREDUCE, NumelIn=513000, NumelOut=513000, Timeout(ms)=600000) ran for 600016 milliseconds before timing out.
[rank0]:[E ProcessGroupNCCL.cpp:1537] [PG 0 Rank 0] Timeout at NCCL work: 3917, last enqueued NCCL work: 3928, last completed NCCL work: 3916.
[rank0]:[E ProcessGroupNCCL.cpp:577] [Rank 0] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank0]:[E ProcessGroupNCCL.cpp:583] [Rank 0] To avoid data inconsistency, we are taking the entire process down.
[rank0]:[E ProcessGroupNCCL.cpp:1414] [PG 0 Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3917, OpType=ALLREDUCE, NumelIn=513000, NumelOut=513000, Timeout(ms)=600000) ran for 600016 milliseconds before timing out.
Exception raised from checkTimeout at …/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:565 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f6880ecf897 in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7f6834a68c62 in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1a0 (0x7f6834a6da80 in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f6834a6edcc in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xd44a3 (0x7f68804d44a3 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #5: <unknown function> + 0x89134 (0x7f6881d2f134 in /lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x1097dc (0x7f6881daf7dc in /lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
what(): [PG 0 Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3917, OpType=ALLREDUCE, NumelIn=513000, NumelOut=513000, Timeout(ms)=600000) ran for 600016 milliseconds before timing out.
Exception raised from checkTimeout at …/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:565 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f6880ecf897 in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7f6834a68c62 in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1a0 (0x7f6834a6da80 in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f6834a6edcc in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xd44a3 (0x7f68804d44a3 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #5: <unknown function> + 0x89134 (0x7f6881d2f134 in /lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x1097dc (0x7f6881daf7dc in /lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at …/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1418 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f6880ecf897 in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe32119 (0x7f68346f2119 in /home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xd44a3 (0x7f68804d44a3 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x89134 (0x7f6881d2f134 in /lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x1097dc (0x7f6881daf7dc in /lib/x86_64-linux-gnu/libc.so.6)

Fatal Python error: Aborted

Thread 0x00007f6763bd36c0 (most recent call first):

Thread 0x00007f676dbd76c0 (most recent call first):

Thread 0x00007f6881ca52c0 (most recent call first):
File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/autograd/graph.py", line 744 in _engine_run_backward
File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267 in backward
File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/_tensor.py", line 525 in backward
File "/home/llibutti/pytorch_env/torchx_examples/resnet_elastic1.py", line 82 in _run_batch
File "/home/llibutti/pytorch_env/torchx_examples/resnet_elastic1.py", line 97 in _run_epoch
File "/home/llibutti/pytorch_env/torchx_examples/resnet_elastic1.py", line 110 in train
File "/home/llibutti/pytorch_env/torchx_examples/resnet_elastic1.py", line 148 in main
File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347 in wrapper
File "/home/llibutti/pytorch_env/torchx_examples/resnet_elastic1.py", line 160 in <module>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, PIL._imaging, PIL._imagingft (total: 22)
E0708 16:00:27.766000 140702371304128 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: -6) local_rank: 0 (pid: 977782) of binary: /home/llibutti/pytorch_env/pytorch_env/bin/python3
Traceback (most recent call last):
  File "/home/llibutti/pytorch_env/pytorch_env/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/distributed/run.py", line 879, in main
    run(args)
  File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/llibutti/pytorch_env/pytorch_env/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

resnet_elastic1.py FAILED

Failures:
<NO_OTHER_FAILURES>

Root Cause (first observed failure):
[0]:
time : 2024-07-08_16:00:27
host : marbore.dacya.ucm.es
rank : 0 (local_rank: 0)
exitcode : -6 (pid: 977782)
error_file: <N/A>
traceback : Signal 6 (SIGABRT) received by PID 977782
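
Most of the diagnosis sits in two places of that output: the watchdog line says the all-reduce with SeqNum=3917 ran past its 600000 ms (10 minute) timeout, and torchrun's exitcode -6 corresponds to SIGABRT, which is how the watchdog's "taking the entire process down" shows up to the launcher. A small sketch that pulls those values out of the log text; `summarize_watchdog` and `describe_exitcode` are hypothetical helper names, not PyTorch APIs, and the regex assumes the exact message format shown above.

```python
import re
import signal

# Hypothetical helper: matches the WorkNCCL(...) timeout message format seen in this log.
WATCHDOG_RE = re.compile(
    r"WorkNCCL\(SeqNum=(?P<seq>\d+), OpType=(?P<op>\w+), "
    r"NumelIn=(?P<numel_in>\d+), NumelOut=(?P<numel_out>\d+), "
    r"Timeout\(ms\)=(?P<timeout_ms>\d+)\) ran for (?P<elapsed_ms>\d+) milliseconds"
)


def summarize_watchdog(line: str) -> dict:
    """Extract the collective's sequence number, type, size and timing."""
    m = WATCHDOG_RE.search(line)
    if m is None:
        raise ValueError("not a watchdog timeout line")
    # Every captured field is numeric except the operation type.
    info = {k: (v if k == "op" else int(v)) for k, v in m.groupdict().items()}
    info["timeout_min"] = info["timeout_ms"] / 60_000
    info["timed_out"] = info["elapsed_ms"] >= info["timeout_ms"]
    return info


def describe_exitcode(code: int) -> str:
    """A negative exitcode from torchrun means the worker was killed by signal -code."""
    if code < 0:
        return signal.Signals(-code).name
    return f"exit status {code}"


line = ("[Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=3917, "
        "OpType=ALLREDUCE, NumelIn=513000, NumelOut=513000, Timeout(ms)=600000) "
        "ran for 600016 milliseconds before timing out.")
info = summarize_watchdog(line)
print(info["op"], info["timeout_min"], info["timed_out"])  # ALLREDUCE 10.0 True
print(describe_exitcode(-6))  # SIGABRT (on Linux)
```

The ten-minute hang in the post matches `timeout_min`: the surviving rank sits in the all-reduce until the watchdog gives up.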

These are the commands used to launch the two nodes:

torchrun --nnodes=1:2 --nproc-per-node=1 --rdzv_id=22 --rdzv_backend=etcd-v2 --rdzv_endpoint=localhost:2379 --node_rank=0 --rdzv-conf=is_host=1 resnet_elastic1.py 10 1

torchrun --nnodes=1:2 --nproc-per-node=1 --rdzv_id=22 --rdzv_backend=etcd-v2 --rdzv_endpoint=localhost:2379 --node_rank=1 --rdzv-conf=is_host=0 resnet_elastic1.py 10 1
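
In both commands `--nnodes=1:2` is torchrun's elastic `MIN:MAX` form: the rendezvous may complete with as few as one and as many as two nodes, while a single number fixes the size. The sketch below mirrors that parsing in plain Python (it is a hypothetical illustration, not torchrun's own code) to show that one surviving node is still inside the allowed range.

```python
from typing import Tuple


def parse_nnodes(spec: str) -> Tuple[int, int]:
    """Hypothetical mirror of torchrun's --nnodes parsing: "N" or "MIN:MAX"."""
    parts = spec.split(":")
    if len(parts) == 1:
        n = int(parts[0])
        return n, n  # fixed-size job: min == max
    if len(parts) == 2:
        lo, hi = int(parts[0]), int(parts[1])
        if not 0 < lo <= hi:
            raise ValueError(f"invalid range {spec!r}")
        return lo, hi
    raise ValueError(f"invalid --nnodes value {spec!r}")


def can_continue(spec: str, alive_nodes: int) -> bool:
    """True if a rendezvous with `alive_nodes` nodes falls inside the range."""
    lo, hi = parse_nnodes(spec)
    return lo <= alive_nodes <= hi


print(parse_nnodes("1:2"))     # (1, 2)
print(can_continue("1:2", 1))  # True: one surviving node is within the range
print(can_continue("2", 1))    # False: a fixed-size job needs both nodes
```

The range by itself does not restart anything: in the log above, the surviving worker only exits once the NCCL watchdog aborts it.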

As I understand it, the problem is the ALLREDUCE issued during the backward pass of the DDP model, which can never complete because one node has failed. How do I get torchrun to restart so that it uses only the nodes that are still alive at that point?
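
For context on the restart part: as far as I understand torchrun's elastic model, on a worker failure the agent stops all of its workers, re-runs the rendezvous (which, with `--nnodes=1:2`, may complete with the single surviving node) and launches the training script again from the start, at most `--max-restarts` times. That flag defaults to 0 and neither command above sets it. A restarted script therefore has to resume from its own saved state. The sketch below shows that snapshot-and-resume pattern in plain Python; the dict state, the pickle file and the `train`/`save_snapshot`/`load_snapshot` names are hypothetical stand-ins for the real model's `state_dict()` saved with `torch.save`.

```python
import os
import pickle
import tempfile
from typing import Optional


def load_snapshot(path: str) -> dict:
    # A fresh start if no snapshot exists yet; otherwise resume from it.
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    return {"epoch": 0, "weights": 0.0}


def save_snapshot(path: str, state: dict) -> None:
    # Write to a temporary file, then rename, so a kill mid-write
    # cannot leave a half-written snapshot behind.
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp, path)


def train(path: str, total_epochs: int, crash_after: Optional[int] = None) -> dict:
    state = load_snapshot(path)  # every (re)started worker begins here
    for epoch in range(state["epoch"], total_epochs):
        state["weights"] += 1.0  # placeholder for one epoch of training
        state["epoch"] = epoch + 1
        save_snapshot(path, state)  # in real DDP code, typically only rank 0 saves
        if crash_after is not None and state["epoch"] == crash_after:
            raise RuntimeError("simulated node failure")
    return state


path = os.path.join(tempfile.mkdtemp(), "snapshot.pkl")
try:
    train(path, total_epochs=10, crash_after=4)  # first launch dies after epoch 4
except RuntimeError:
    pass
final = train(path, total_epochs=10)  # the "restart" resumes at epoch 4
print(final)  # {'epoch': 10, 'weights': 10.0}
```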