- Save a model that is wrapped with DataParallel().
torch.save(model.module.state_dict(), save_folder + '/' + 'model.pt')
- Load this model and train it with SyncBatchNorm + DDP.
# Define Model
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.cuda(args.gpu)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
# Load Model
loc = 'cuda:{}'.format(args.gpu)
model.module.load_state_dict(torch.load(save_folder + '/model.pt', map_location=loc), strict=False)
Error(s) in loading state_dict for DistributedDataParallel:
Missing key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", ...
How can I load a model trained with DataParallel() after wrapping it with SyncBatchNorm + DDP?
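For what it's worth, the missing keys all carry a `module.` prefix, which suggests `load_state_dict` is being called on the DDP wrapper while the checkpoint was saved from the unwrapped `model.module.state_dict()` (so its keys have no prefix). A minimal sketch of one possible workaround, assuming the only mismatch is that prefix (plain dict values stand in for real tensors; `add_module_prefix` is a hypothetical helper, not part of torch):

```python
def add_module_prefix(state_dict, prefix="module."):
    """Re-key a checkpoint saved from model.module.state_dict() so its
    keys match what the DDP wrapper's own load_state_dict expects."""
    return {k if k.startswith(prefix) else prefix + k: v
            for k, v in state_dict.items()}

# Dummy checkpoint keys, as produced by model.module.state_dict():
ckpt = {"conv1.weight": 0, "bn1.weight": 1}
remapped = add_module_prefix(ckpt)
# remapped now has keys "module.conv1.weight", "module.bn1.weight"
```

With the keys remapped, loading through the wrapper (`model.load_state_dict(remapped)`) should line up; alternatively, loading through `model.module.load_state_dict(...)` with the original un-prefixed keys avoids the remap entirely.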