How to handle an OOM exception when training with DDP and SyncBatchNorm?

When training with DDP and SyncBatchNorm, with one process running on each GPU, the training hangs when one process hits a GPU OOM exception, even though I catch it. What should I do?
My code is below: when an OOM exception occurs in a process, I just skip that batch and let the training loop continue.

for i, (inputs, targets) in enumerate(train_loader):
    try:
        # do forward and backprop
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print('| WARNING: ran out of memory, skipping this batch.')
            # Release cached blocks so the next batch has a chance to fit
            # (the hasattr check only guards very old PyTorch versions).
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()
            # Drop any partially accumulated gradients from the failed batch
            optimizer.zero_grad()
        else:
            raise

When one process catches the exception and skips the batch, the other processes get blocked.
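
As far as I can tell, this is because DDP's backward pass and SyncBatchNorm's forward pass both issue collective operations (all-reduces) that every rank must enter; if one rank skips the batch, its peers wait on those collectives indefinitely. Below is a minimal sketch of that behaviour, separate from the training script above, using an assumed two-rank setup launched with torchrun, the gloo backend on CPU, and a short process-group timeout so the blocked rank errors out instead of hanging silently.

# Minimal repro sketch (assumed two-rank setup, launched with:
#   torchrun --nproc_per_node=2 ddp_skip_demo.py).
# Rank 0 pretends it hit an OOM and skips the batch; rank 1 then blocks
# inside DDP's gradient all-reduce because rank 0 never issues the
# matching collective.
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.nn as nn


def main():
    # gloo on CPU keeps the demo GPU-free; the timeout turns the silent
    # hang into a visible error after ~20 seconds.
    dist.init_process_group(backend="gloo", timeout=timedelta(seconds=20))
    rank = dist.get_rank()

    model = nn.parallel.DistributedDataParallel(nn.Linear(8, 1))
    inputs, targets = torch.randn(4, 8), torch.randn(4, 1)

    if rank == 0:
        # Simulated OOM: this rank silently skips the batch.
        print("rank 0: skipping this batch (simulated OOM)")
    else:
        loss = nn.functional.mse_loss(model(inputs), targets)
        # Rank 1 gets stuck in DDP's gradient all-reduce, because rank 0
        # never issues the matching collective; it eventually errors out
        # (timeout / peer gone) instead of completing.
        loss.backward()
        print("rank 1: finished backward")  # not reached in this demo

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

So, as I understand it, a single rank cannot safely skip a batch on its own: either all ranks have to agree to skip before the next collective call, or the OOM rank still has to go through a forward/backward that issues the same collectives as its peers.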
