I’m using DDP (one process per GPU) to train a 3D UNet. I converted all BatchNorm layers in the network to SyncBatchNorm with nn.SyncBatchNorm.convert_sync_batchnorm.
When I run validation at the end of every training epoch on rank 0 only, it always freezes at the same validation step. I think this is because of the SyncBatchNorm layers. What is the correct way to do validation when the DDP model has SyncBatchNorm layers? Should I do validation on all ranks?
Code
for epoch in range(epochs):
    model.train()
    train_loader.sampler.set_epoch(epoch)
    for step, (data, target) in enumerate(train_loader):
        # ...training codes
        train(model)
    if dist.get_rank() == 0:
        # ...validation codes
        model.eval()
        validate(model)
    dist.barrier()
Could you update to the latest stable release or the nightly binary and check if you are still facing the error? 1.1.0 is quite old by now, and this issue might have already been fixed.
Yes, you probably need to do validation on all ranks since SyncBatchNorm has collectives which are expected to run on all ranks. The validation is probably getting stuck since SyncBatchNorm on rank 0 is waiting for collectives from other ranks.
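For example, a minimal sketch of running validation on all ranks could look like the code below. Here validate_step() is a hypothetical helper that returns a (loss sum, sample count) pair for one batch, and val_loader is assumed to use a DistributedSampler so each rank sees its own shard; neither is part of your original code.

import torch
import torch.distributed as dist

# every rank runs validation on its own shard, then the summed loss/count
# are reduced across ranks so rank 0 can report the global metric
model.eval()
loss_sum, count = 0.0, 0
with torch.no_grad():
    for data, target in val_loader:
        batch_loss, batch_count = validate_step(model, data, target)  # hypothetical helper
        loss_sum += batch_loss
        count += batch_count
stats = torch.tensor([loss_sum, count], device="cuda")  # current GPU of this process
dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # every rank participates in the collective
if dist.get_rank() == 0:
    print(f"val loss: {(stats[0] / stats[1]).item():.4f}")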
Another option is to convert the SyncBatchNorm layer to a regular BatchNorm layer and then do the validation on a single rank.
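As far as I know there is no built-in revert for this, so below is a rough sketch of what such a conversion could look like, mirroring the structure of convert_sync_batchnorm. The helper name is made up, and it assumes BatchNorm3d layers (since this is a 3D UNet) with track_running_stats enabled.

import copy
import torch.nn as nn
import torch.distributed as dist

def revert_sync_batchnorm(module):
    # hypothetical helper (not a PyTorch API): recursively replace every
    # SyncBatchNorm with a plain BatchNorm3d that reuses the same affine
    # parameters and running statistics
    converted = module
    if isinstance(module, nn.SyncBatchNorm):
        converted = nn.BatchNorm3d(module.num_features, module.eps,
                                   module.momentum, module.affine,
                                   module.track_running_stats)
        if module.affine:
            converted.weight = module.weight
            converted.bias = module.bias
        converted.running_mean = module.running_mean
        converted.running_var = module.running_var
        converted.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        converted.add_module(name, revert_sync_batchnorm(child))
    return converted

# usage sketch: validate a reverted copy on rank 0 so the DDP model is untouched
if dist.get_rank() == 0:
    eval_model = revert_sync_batchnorm(copy.deepcopy(model.module)).eval()
    validate(eval_model)
dist.barrier()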
Thanks, I will try it.
Actually I have another question about v1.1.0 DDP.
I tried to run inference with the model that has SyncBatchNorm layers (actually, they become regular BatchNorm layers after loading from the checkpoint). The results turned out to be different between:
1. Only turning on evaluation mode:
model.eval()
# inference...
2. Manually setting track_running_stats of each BN layer to False after model.eval():
model.eval()
set_BN_track_running_stats(model, False)
# do inference..
It is strange that the second one is much better than the first one at early epochs. Is this also a version problem?
P.S. However, after training for more epochs, the results of the two inference methods are similar but still show small differences.
Below is the sample code for set_BN_track_running_stats():
def set_BN_track_running_stats(module, ifTrack=True):
    # recursively toggle track_running_stats on every BatchNorm layer
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.track_running_stats = ifTrack
    for child in module.children():
        set_BN_track_running_stats(module=child, ifTrack=ifTrack)
Are you referring to PyTorch v1.1.0 here? If so, I’d suggest upgrading to PyTorch 1.7, which is the latest version, to see if this problem still persists.