How to use "break" in DistributedDataParallel training

I am using DistributedDataParallel to train a model on multiple GPUs. If I want to stop training early, how can I do that? Thanks.

Is this about uneven inputs on different processes? See:

  1. https://github.com/pytorch/pytorch/issues/33148
  2. https://github.com/pytorch/pytorch/issues/38174

If all processes know when to exit, simply breaking out of the loop works. The tricky case is when one process breaks out of the loop but the other processes keep going, as described in the two issues above.

Indeed, this is exactly my situation: one process breaks out of the loop while the others continue. A process breaks when its loss on the eval dataset increases (overfitting). Do you have any ideas? Thanks.

Ideally, we should address this in DDP and close https://github.com/pytorch/pytorch/issues/38174. Until that happens, you can use all_reduce to synchronize a stop signal across all processes. See the thread "Multiprocessing - Barrier Blocks all Processes?".
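A self-contained sketch of that idea, under a few assumptions: it uses the CPU `gloo` backend and two processes started with `torch.multiprocessing.spawn`, so it runs without GPUs, and `should_stop_locally` is a hypothetical placeholder for a real per-rank check. Each rank votes 1 or 0, and `ReduceOp.MAX` makes every rank see 1 as soon as any rank votes to stop, so all of them break in the same epoch.

```python
import socket
from multiprocessing import Manager

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def should_stop_locally(rank: int, epoch: int) -> bool:
    # Hypothetical stand-in for a real check such as "my validation loss
    # went up". Here only rank 1 asks to stop, at epoch 3.
    return rank == 1 and epoch == 3


def worker(rank: int, world_size: int, port: int, results) -> None:
    dist.init_process_group(
        "gloo",  # CPU backend so the sketch runs without GPUs
        init_method=f"tcp://127.0.0.1:{port}",
        rank=rank,
        world_size=world_size,
    )
    stopped_at = None
    for epoch in range(10):
        # ... one epoch of training and evaluation would go here ...
        flag = torch.tensor([1.0 if should_stop_locally(rank, epoch) else 0.0])
        # Every rank calls all_reduce once per epoch; MAX turns the flag
        # into 1.0 on all ranks as soon as any single rank votes to stop.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        if flag.item() > 0:
            stopped_at = epoch
            break  # all ranks reach this break in the same epoch
    results[rank] = stopped_at
    dist.destroy_process_group()


def free_port() -> int:
    # Ask the OS for an unused TCP port for the rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def run(world_size: int = 2) -> dict:
    with Manager() as manager:
        results = manager.dict()  # shared dict: rank -> epoch it stopped at
        mp.spawn(worker, args=(world_size, free_port(), results), nprocs=world_size)
        return dict(results)


if __name__ == "__main__":
    print(sorted(run().items()))  # [(0, 3), (1, 3)]
```

Using `SUM` and testing for a value greater than 0 would behave the same way; `MAX` simply keeps the flag at 1 no matter how many ranks vote to stop.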

One thing to note is that this might hurt performance, because the processes now have to communicate the signal on every iteration. The cost is most noticeable when the model is light and its forward pass runs faster than the signal can be communicated.
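If that overhead matters, one possible mitigation (an assumption, not something suggested in this thread) is to synchronize the flag only every `check_every` steps. `stop_requested` below is a hypothetical helper; it only works if every rank calls it with the same step counter and the same `check_every`, so that all ranks issue the same number of collectives.

```python
from typing import Optional

import torch
import torch.distributed as dist


def stop_requested(
    local_stop: bool,
    step: int,
    check_every: int,
    device: Optional[torch.device] = None,
) -> bool:
    """Return True on every rank once any rank has set local_stop.

    The all_reduce only runs when step % check_every == 0. All ranks must
    pass the same step and check_every; if they issued different numbers
    of all_reduce calls, the job would hang.
    """
    if step % check_every != 0:
        return False  # skip the communication on most steps
    # With the NCCL backend, device must be this rank's GPU.
    flag = torch.tensor([1.0 if local_stop else 0.0], device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item() > 0
```

The caller should keep its local request sticky (for example `want_stop = want_stop or condition`) so that a stop raised between two checks is not lost; the trade-off is that training may run up to `check_every - 1` extra steps after the condition first fires.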

Thanks for your help. I will probably just set a fixed number of epochs instead, which is simple though not optimal.

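import torch
import torch.distributed as dist

# num_epochs, device, local_rank and stop_condition are assumed to be defined elsewhere
for epoch in range(num_epochs):
    # ... train and evaluate for one epoch ...
    flag_tensor = torch.zeros(1).to(device)
    if local_rank == 0:
        if stop_condition:  # your condition for breaking the loop
            flag_tensor += 1
    # Every rank must reach this all_reduce, or the collective hangs
    dist.all_reduce(flag_tensor, op=dist.ReduceOp.SUM)
    if flag_tensor.item() > 0:
        print("Training stopped")
        break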
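The snippet leaves the stopping condition open. One option, assumed here rather than taken from the thread, is a patience rule on the validation loss: `EarlyStopping` is a hypothetical helper that only asks to stop after `patience` evaluations in a row without improvement, so a single noisy uptick does not end training (with `patience=1` it stops at the first evaluation that fails to improve).

```python
class EarlyStopping:
    """Hypothetical helper: request a stop after `patience` consecutive
    evaluations whose loss is not at least `min_delta` below the best so far."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_evals = 0  # evaluations in a row without improvement

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True once training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


stopper = EarlyStopping(patience=2)
print([stopper.step(loss) for loss in [1.0, 0.8, 0.85, 0.9]])  # [False, False, False, True]
```

On rank 0, the result of `stopper.step(val_loss)` (where `val_loss` is that epoch's validation loss) would serve as the condition for breaking the loop in the snippet above.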

Thank you for your simple yet elegant solution; it worked. I wanted my training script to end once an early-stopping criterion on the validation loss was met, and now all the processes exit correctly.
