I've noticed that most people first convert BatchNorm to SyncBatchNorm and then wrap the model with DistributedDataParallel:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
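For context, here is a minimal sketch of the full setup I have in mind (assuming a single-node launch with torchrun, so LOCAL_RANK is set in the environment; the Sequential model is just a placeholder containing a BatchNorm layer):

import os
import torch
import torch.distributed as dist
import torch.nn as nn

# torchrun sets LOCAL_RANK for each spawned process
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# placeholder model with a BatchNorm2d layer to be converted
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
).cuda(local_rank)

# convert BatchNorm -> SyncBatchNorm first, then wrap with DDP
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])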
If I reverse the order, as shown below, would I get the same results?
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)