How to load models wrapped with SyncBatchNorm?

  1. Save the model while it is wrapped with DataParallel().
torch.save(model.module.state_dict(), save_folder + '/' + 'model.pt')
  2. Load this model and train it with SyncBatchNorm + DDP.
# Define Model
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.cuda(args.gpu)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
# Load Model
loc = 'cuda:{}'.format(args.gpu)
model.module.load_state_dict(torch.load(save_folder + '/model.pt', map_location=loc), strict=False)

Error(s) in loading state_dict for DistributedDataParallel:
Missing key(s) in state_dict: "module.con1.weight", "module.bn1.weight", ...

How can I load my models trained with DataParallel() after wrapping them with SyncBatchNorm + DDP?

Just do it like this:

Define Model

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.cuda(args.gpu)
model_DDP = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)

Save and load the model

# tmp is a temporary file handle (e.g. from tempfile.NamedTemporaryFile())
# Note: this saves the entire DDP-wrapped module, not just its state_dict
torch.save(model_DDP, tmp.name)
model_DDP = torch.load(tmp.name)