Validation hangs when using DDP and SyncBatchNorm

I’m using DDP (one process per GPU) to train a 3D UNet. I converted all BatchNorm layers in the network to SyncBatchNorm with nn.SyncBatchNorm.convert_sync_batchnorm.
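For reference, the conversion step can be sketched as follows. This is a minimal sketch under assumptions: `build_model` is a hypothetical toy 3D network standing in for the UNet, and `local_rank` in the comments is a hypothetical per-process GPU index. `convert_sync_batchnorm` only swaps modules, so it does not need an initialized process group; the DDP wrapping shown in the comments does.

```python
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for the 3D UNet: any module tree containing BatchNorm3d
    return nn.Sequential(
        nn.Conv3d(1, 8, kernel_size=3, padding=1),
        nn.BatchNorm3d(8),
        nn.ReLU(inplace=True),
        nn.Conv3d(8, 8, kernel_size=3, padding=1),
        nn.BatchNorm3d(8),
        nn.ReLU(inplace=True),
    )


model = build_model()
# Recursively replace every BatchNorm*d layer with SyncBatchNorm,
# copying the affine parameters and running statistics.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# In each DDP process (one per GPU), after dist.init_process_group(...):
#   model = model.cuda(local_rank)  # local_rank is hypothetical here
#   model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```

Converting first and wrapping in DDP afterwards follows the order used in the `SyncBatchNorm` documentation example.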

When I run validation on rank 0 at the end of every training epoch, it always freezes at the same validation step. I think this is caused by the SyncBatchNorm layers. What is the correct way to do validation when a DDP model has SyncBatchNorm layers? Should I run validation on all ranks?


for epoch in range(epochs):
    for step, (data,  target) in enumerate(train_loader):
        # codes
    if dist.get_rank() == 0:
        # ...validation codes

Version / OS

torch = 1.1.0
ubuntu 18.04
distributed backend: nccl

Similar question:


Could you update to the latest stable release or the nightly binary and check whether you are still facing the error? 1.1.0 is quite old by now, and this issue might have already been fixed.


Yes, you probably need to run validation on all ranks, since SyncBatchNorm uses collective operations that are expected to run on all ranks. The validation is probably getting stuck because SyncBatchNorm on rank 0 is waiting for the matching collectives from the other ranks, which never reach them.
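A minimal sketch of validating on every rank, assuming a `DistributedSampler`-backed validation loader so all ranks see the same number of batches (`validate_all_ranks` and its arguments are hypothetical names, not a PyTorch API). Each rank evaluates its own shard, and `dist.all_reduce` sums the partial results so every rank ends up with the same global loss; when no process group is initialized it simply returns the local result.

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def validate_all_ranks(model, loader, criterion, device):
    """Hypothetical helper: run validation on EVERY rank, return the global mean loss.

    All ranks must call this together; any collective (SyncBatchNorm's or the
    all_reduce below) blocks until every rank reaches it.
    """
    was_training = model.training
    model.eval()
    loss_sum = torch.zeros(1, device=device)
    count = torch.zeros(1, device=device)
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Assumes criterion returns the mean loss over the batch
        loss_sum += criterion(output, target) * data.size(0)
        count += data.size(0)
    # Combine the per-rank partial sums; skipped when not running distributed
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
    model.train(was_training)
    return (loss_sum / count).item()
```

In the loop from the question, every rank would call `validate_all_ranks(...)` instead of entering the `if dist.get_rank() == 0:` branch, and only rank 0 would log or save checkpoints. `DistributedSampler` may pad the dataset with a few repeated samples to give every rank an equal shard, which slightly skews the global mean.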

Another option is to convert the SyncBatchNorm layers to regular BatchNorm layers and then run validation on a single rank.
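The second option could look like the sketch below. As far as I know, `torch.nn` does not ship a reverse of `convert_sync_batchnorm`, so `revert_sync_batchnorm` here is a hypothetical helper. It assumes a 3D network (hence `nn.BatchNorm3d` as the default replacement class) and copies parameters and running statistics through `state_dict`, whose keys match between the two layer types.

```python
import itertools

import torch.nn as nn


def revert_sync_batchnorm(module: nn.Module, bn_cls=nn.BatchNorm3d) -> nn.Module:
    """Hypothetical helper (not a torch.nn API): recursively replace every
    nn.SyncBatchNorm with bn_cls, copying parameters and running statistics.
    bn_cls defaults to BatchNorm3d on the assumption of a 3D network."""
    module_output = module
    if isinstance(module, nn.SyncBatchNorm):
        module_output = bn_cls(
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        # Same keys: weight, bias, running_mean, running_var, num_batches_tracked
        module_output.load_state_dict(module.state_dict())
        # Keep the new layer on the same device as the old one
        ref = next(itertools.chain(module.parameters(), module.buffers()), None)
        if ref is not None:
            module_output = module_output.to(ref.device)
        module_output.train(module.training)
    # Recurse into children; container modules are modified in place
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child, bn_cls))
    return module_output
```

On rank 0, converting a copy keeps the DDP-wrapped training model untouched, e.g. `eval_model = revert_sync_batchnorm(copy.deepcopy(model.module)).eval()`.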


Thanks, I will try it.
Actually, I have another question about DDP in v1.1.0.
I ran inference with the model that has SyncBatchNorm layers (they actually become regular BatchNorm layers after loading from the checkpoint). The results differ between:

  1. Only switching to evaluation mode:
     model.eval()
     # inference...
  2. Manually setting track_running_stats of each BN layer to False after model.eval():
     model.eval()
     set_BN_track_running_stats(model, False)
     # inference...

Strangely, the second method gives much better results than the first in early epochs. Is this also a version problem?

P.S. After more epochs of training, the results of the two inference methods become similar, but small differences remain.

Here is the code for set_BN_track_running_stats():

import torch

def set_BN_track_running_stats(module, ifTrack=True):
    # Set track_running_stats on every BatchNorm-style layer in the model
    # (BatchNorm1d/2d/3d and SyncBatchNorm all derive from _BatchNorm).
    for m in module.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.track_running_stats = ifTrack
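One plausible explanation of the difference, offered as a hedged sketch rather than a confirmed diagnosis: in older releases such as 1.1.0, `_BatchNorm.forward` passed `self.training or not self.track_running_stats` as the `training` argument of `F.batch_norm` (recalled from that era's source, so treat it as an assumption; newer releases may behave differently). Under that behavior, method 1 normalizes with the running estimates, while method 2 normalizes each inference batch with its own mean and variance. Early in training the running estimates (updated with the default `momentum=0.1`) can lag far behind the actual activation statistics, so the two diverge most then. The toy tensors below are made up to exaggerate that gap.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Made-up 3D feature maps (N, C, D, H, W) whose per-channel mean is about 5
x = torch.randn(4, 2, 3, 8, 8) + 5.0

# Running estimates still at their initial values (mean 0, var 1), mimicking
# stale statistics early in training
running_mean = torch.zeros(2)
running_var = torch.ones(2)

# Method 1 (plain model.eval()): normalize with the running estimates
out_running = F.batch_norm(x, running_mean, running_var, training=False)

# Method 2 (assumed 1.1.0 behavior with track_running_stats=False in eval):
# normalize with this batch's own statistics; None means nothing is updated
out_batch = F.batch_norm(x, None, None, training=True)

per_channel = (0, 2, 3, 4)
print(out_running.mean(dim=per_channel))  # roughly 5 per channel: offset not removed
print(out_batch.mean(dim=per_channel))    # essentially 0 per channel
```

With batch statistics, each prediction depends on the other samples in the inference batch (and on its size), whereas running estimates give per-sample outputs that no longer depend on the batch once they have converged, which fits the gap shrinking after more epochs.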

Yes, validating on all GPUs works. But shouldn’t model.eval() automatically disable the synchronization?

BTW, thanks for your reply.

Are you referring to PyTorch v1.1.0 here? If so, I’d suggest upgrading to PyTorch 1.7, the latest version, to see if the problem persists.

Good point, I’ve opened an issue for this:


Thanks a lot for your reply. I will try v1.7.

I upgraded torch to 1.7 today and the problem is gone, so I think it was a 1.1.0 issue. Thanks again for your help.

Actually, I ran into another problem after upgrading to v1.7.0: the results are much worse than they were on 1.1. Could you help me with that?