Set early stopping criteria using DDP


I would like to set an early stopping criterion in my DDP model. I have a single node with 8 GPUs, and am training with DDP and a DistributedSampler, launched via torch.distributed.launch. I’m implementing the early stopping criterion as follows:

early_stop = torch.zeros(1, device=local_rank)
if local_rank == 0:
    # get current loss on masked and non-masked validation tokens
    loss, loss_missing = logger.loss()

    # stop_value is a boolean flag indicating whether the stopping criterion has been met
    stop_value = logger.step(ddp_model, loss_missing)
    stop_value = torch.tensor(stop_value)
    early_stop = early_stop + stop_value

    # synchronize `early_stop` across all devices
    # if any `early_stop` equals 1 after synchronization, training should stop
    dist.all_reduce(early_stop, op=dist.ReduceOp.SUM)

if early_stop != 0:
    print(f'Breaking at epoch: {epoch}, rank: {local_rank}')

However, I can’t get this to run without generating CUDA errors. My training script works when I comment out all instances of dist.all_reduce, but that defeats the purpose of implementing early stopping.

Here is the generated error, after setting export CUDA_LAUNCH_BLOCKING=1:

Traceback (most recent call last):
  File "/home/krieschenburg/code/jds-abDND/bin/", line 380, in <module>
    main(args.local_world_size, args.local_rank, args)
  File "/home/krieschenburg/code/jds-abDND/bin/", line 242, in main
    train(args.local_world_size, args.local_rank, args)
  File "/home/krieschenburg/code/jds-abDND/bin/", line 113, in train
    dist.all_reduce(early_stop, op=dist.ReduceOp.SUM)
  File "/home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/distributed/", line 1320, in all_reduce
    work = default_pg.allreduce([tensor], opts)
RuntimeError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:46, unhandled cuda error, NCCL version 2.10.3
ncclUnhandledCudaError: Call to CUDA function failed.
[W CUDAGuardImpl.h:113] Warning: CUDA warning: the launch timed out and was terminated (function destroyEvent)
terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: the launch timed out and was terminated
Exception raised from create_event_internal at ../c10/cuda/CUDACachingAllocator.cpp:1387 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f36553e6612 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #1: <unknown function> + 0x22c1e (0x7f3655655c1e in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x22d (0x7f3655658c4d in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #3: <unknown function> + 0x339668 (0x7f369ec6e668 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #4: c10::TensorImpl::release_resources() + 0x175 (0x7f36553cb295 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #5: <unknown function> + 0x11d6fdd (0x7f3686c79fdd in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #6: c10d::Reducer::~Reducer() + 0x254 (0x7f3689a4c8c4 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #7: std::_Sp_counted_ptr<c10d::Reducer*, (__gnu_cxx::_Lock_policy)2>::_M_dispose() + 0x12 (0x7f369f1b4192 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #8: std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() + 0x46 (0x7f369eb49ca6 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #9: <unknown function> + 0x8827af (0x7f369f1b77af in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #10: <unknown function> + 0x21c2c0 (0x7f369eb512c0 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #11: <unknown function> + 0x21d46e (0x7f369eb5246e in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/lib/python3.8/site-packages/torch/lib/
frame #12: /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python() [0x5d41f4]
frame #13: /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python() [0x5a929d]
frame #14: /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python() [0x5ee8c0]
frame #15: /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python() [0x544e48]
frame #16: /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python() [0x544e9a]
frame #17: /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python() [0x544e9a]
frame #18: PyDict_SetItemString + 0x538 (0x5d3048 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python)
frame #19: PyImport_Cleanup + 0x79 (0x687ad9 in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python)
frame #20: Py_FinalizeEx + 0x7f (0x682aef in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python)
frame #21: Py_RunMain + 0x32d (0x6b9e4d in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python)
frame #22: Py_BytesMain + 0x2d (0x6ba0bd in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python)
frame #23: __libc_start_main + 0xf3 (0x7f36a6e300b3 in /lib/x86_64-linux-gnu/
frame #24: _start + 0x2e (0x5fc5fe in /home/krieschenburg/.local/share/virtualenvs/jds-abDND-Yct67OkM/bin/python)

I’ve found various forum and GitHub posts (1), (2), (3), but haven’t been able to fix this issue.

Any help is much appreciated.


Why do you have this if condition? In other words, why are you only computing stop_value on local rank 0? This seems to contradict your comment that training should stop if any early_stop equals 1. In your 1-node, 8-GPU setup, there is only one rank 0, so only one process ever sets the flag.

The immediate cause of your hang seems to be that dist.all_reduce() is called only on rank 0, when a collective must be called by every rank in the process group.


My mistake – I typed my code incorrectly. The dist.all_reduce is in fact called outside of the if local_rank == 0 condition.

That said, updating early_stop only on rank 0 inside the if condition shouldn’t affect the dist.all_reduce() calls themselves. If rank 0 doesn’t meet the stopping criterion (i.e. early_stop on rank 0 stays 0), training should simply continue, since the dist.ReduceOp.SUM reduction would leave every rank’s early_stop at 0. Instead, I’m seeing CUDA errors.
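For reference, this is the actual shape of the check (a sketch only, reusing the names from the snippet above, with the epoch-loop scaffolding omitted):

```python
early_stop = torch.zeros(1, device=local_rank)
if local_rank == 0:
    loss, loss_missing = logger.loss()
    stop_value = logger.step(ddp_model, loss_missing)
    early_stop = early_stop + torch.tensor(float(stop_value))

# reached by every rank, not just rank 0
dist.all_reduce(early_stop, op=dist.ReduceOp.SUM)

if early_stop.item() != 0:
    print(f'Breaking at epoch: {epoch}, rank: {local_rank}')
```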

Would you propose a different approach to this?