Synchronize some information between processes at the end of each epoch (use case: setting a time limit for training)

I have a training script that I launch with torch.distributed.launch on multiple GPUs. I would like to set a time limit, so that training stops early instead of exceeding this limit. Something like this:

```python
import time
from collections import deque

# for storing the running times of the last 3 epochs
epoch_time_queue = deque(maxlen=3)

# start_epoch, args, train_epoch and eval_epoch come from the rest of the script
start_time = time.time()
for epoch in range(start_epoch, args.epochs):
    start_epoch_time = time.time()
    # training
    train_epoch(...)
    # validation
    eval_epoch(...)
    # epoch time in minutes
    epoch_time = (time.time() - start_epoch_time) / 60
    # average duration of the last 3 epochs
    epoch_time_queue.append(epoch_time)
    avg_time = sum(epoch_time_queue) / len(epoch_time_queue)
    # if the next epoch will likely surpass the time limit, then stop here
    estimated_next_total_time = (time.time() - start_time) / 60 + avg_time
    if args.time_limit > 0 and estimated_next_total_time > args.time_limit:
        break
```

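The issue is that the elapsed time may differ between processes. For example, at the end of the 5th epoch, the process on GPU1 may estimate that it would exceed the time limit by a few seconds during the next (6th) epoch, so it stops, while the process on GPU2 estimates that it can finish the 6th epoch a few seconds before the limit, so it continues. The processes then fall out of step: if the model is wrapped in DistributedDataParallel, for instance, the process that continues would typically hang (or fail with a timeout) at the next gradient all-reduce, waiting for a peer that has already left the loop.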
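To make the divergence concrete, the sketch below applies the loop's stop rule separately on two ranks. The numbers are hypothetical (a 60-minute limit and slightly different elapsed times per rank), and `local_decision` is only a restatement of the `if` condition in the loop, not a library function.

```python
# Illustration only: hypothetical numbers for the scenario described above.
TIME_LIMIT_MIN = 60.0


def local_decision(elapsed_min: float, avg_epoch_min: float, limit_min: float) -> bool:
    """The stop rule from the training loop, evaluated by one rank on its own clock."""
    return limit_min > 0 and elapsed_min + avg_epoch_min > limit_min


# rank -> (minutes elapsed after epoch 5, average epoch length), measured locally
measurements = {"GPU1": (50.02, 10.0), "GPU2": (49.96, 10.0)}
decisions = {rank: local_decision(elapsed, avg, TIME_LIMIT_MIN)
             for rank, (elapsed, avg) in measurements.items()}
print(decisions)  # {'GPU1': True, 'GPU2': False}
```

Both ranks follow the same rule, yet one breaks out of the loop and the other starts epoch 6, because each one only sees its own clock.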

I would like to know if there is a way for the processes to communicate about this. Ideally, a process should wait for all the other processes to finish the current epoch before deciding whether to start the next one.

@f10w yep, this is possible: you can use the all_gather API to let every process collect the elapsed time from all processes.
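A minimal sketch of that suggestion, assuming the default process group has already been initialized with `init_process_group` (as a script started by torch.distributed.launch normally does). `should_stop` is a hypothetical helper name: each rank contributes its own estimated total time, and every rank applies the limit to the largest gathered estimate, so all ranks reach the same decision. The sketch defaults to CPU tensors, which the gloo backend accepts; with the NCCL backend you would pass the rank's CUDA device instead.

```python
import torch
import torch.distributed as dist


def should_stop(local_estimate_min: float, time_limit_min: float,
                device: torch.device = torch.device("cpu")) -> bool:
    """Hypothetical helper: return the same stop decision on every rank.

    local_estimate_min is this rank's estimated total time (minutes) if it
    runs one more epoch, i.e. estimated_next_total_time in the loop.
    """
    # The time limit is identical on all ranks, so this early return is taken
    # (or not) everywhere and never skips the collective on just one rank.
    if time_limit_min <= 0:
        return False
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: nothing to synchronize.
        return local_estimate_min > time_limit_min

    local = torch.tensor([local_estimate_min], dtype=torch.float64, device=device)
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    # all_gather fills `gathered` with every rank's estimate; no rank can
    # read the result until all ranks have contributed theirs.
    dist.all_gather(gathered, local)
    worst_case = max(t.item() for t in gathered)
    return worst_case > time_limit_min
```

In the loop, the final `if` would become `if should_stop(estimated_next_total_time, args.time_limit): break`. Taking the maximum is a conservative choice (stop as soon as any rank predicts an overrun); the mean or rank 0's value would also give a consistent decision, as long as every rank applies the same rule to the same gathered values.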


Thanks for the reply!
When we call torch.distributed.all_gather(), does it create some kind of "barrier" between the processes?

Yes, it does. All collective communications (e.g., broadcast, all_reduce, all_gather, etc.) can be considered a barrier: no process gets the result until every process in the group has called the same collective.
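If only the final yes/no matters, a leaner variant (a swapped-in alternative to the all_gather approach that relies on the same blocking behavior) is to all_reduce a stop flag with `ReduceOp.MAX`: the result is 1 on every rank if any rank wants to stop. `any_rank_wants_stop` is a hypothetical helper; it assumes an initialized process group, and under NCCL the flag tensor must live on the rank's GPU.

```python
import torch
import torch.distributed as dist


def any_rank_wants_stop(local_stop: bool,
                        device: torch.device = torch.device("cpu")) -> bool:
    """Hypothetical helper: True on every rank if at least one rank wants to stop.

    Assumes torch.distributed.init_process_group() has already been called.
    With the NCCL backend, pass this rank's CUDA device.
    """
    flag = torch.tensor([1 if local_stop else 0], dtype=torch.int32, device=device)
    # Collective call: every rank must reach this line; the reduced value is
    # not available on any rank until all ranks have contributed their flag.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```

Each rank would compute `stop = args.time_limit > 0 and estimated_next_total_time > args.time_limit` exactly as in the original loop and then `break` only if `any_rank_wants_stop(stop)` returns True, so either all ranks start the next epoch or none do.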
