Best way to handle throw_on_early_termination exception for models with sync_bn?

When training a model under DDP that contains a SyncBatchNorm layer on uneven inputs, the suggestion is to pass throw_on_early_termination=True to DDP's join() context manager. What is the best way to handle the exception raised by the join() context manager when one rank exhausts its inputs, if we just want training to continue normally for the next iterations? Should we simply swallow the exception and move on? Should we add a CUDA synchronize call as part of the exception handling? Anything else?

we just want the training to continue normally for next iterations

Assuming you mean continue training with DDP, this exception indicates that is not possible - one rank has finished all of its inputs while the others have not. If you continue training anyway, DDP will hang (most likely at the SyncBatchNorm step), and the exception is designed to flag this before it happens.

The main use case of the exception is as a "signal" to finish the training process: all processes raise this exception, and application code can catch it, terminate the main training loop, and save/evaluate the trained model appropriately. Usually input counts are only off by a few examples across ranks, so it should be fine to terminate training at this stage - if that isn't true, you may want to look into balancing the dataset better across ranks.

Got it, thanks for your reply!