It works! Thank you
I used:
model = resnet50()
model.to(device)
model.load_state_dict(torch.load(cfg.MODEL.pretrained_model_path))
and got this RuntimeError:
RuntimeError: Error(s) in loading state_dict for Resnet:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight" ... ...
Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight" ... ...
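It seems DistributedDataParallel saved the model with every parameter name prefixed by "module." (the wrapper stores the original network in its `module` attribute), so the keys no longer match the unwrapped resnet50. Then I found this solution: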
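import torch
from collections import OrderedDict

state_dict = torch.load(cfg.MODEL.pretrained_model_path)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # drop the 7-character "module." prefix added by DistributedDataParallel
    name = k[7:] if k.startswith("module.") else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)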
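The same fix can be wrapped in a small reusable helper. This is a sketch assuming the checkpoint is a flat mapping of parameter names to tensors; `strip_prefix` is a hypothetical name, not a PyTorch API. It only rewrites keys that actually start with the prefix, so it is also safe on checkpoints saved without the wrapper:

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Hypothetical helper: keys without the prefix are copied unchanged and the
    original key order is preserved.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


if __name__ == "__main__":
    # Plain values stand in for tensors; only the keys matter here.
    ckpt = OrderedDict([("module.conv1.weight", 1), ("module.bn1.weight", 2)])
    print(list(strip_prefix(ckpt)))  # ['conv1.weight', 'bn1.weight']
```

With the variables from above this would be `model.load_state_dict(strip_prefix(torch.load(cfg.MODEL.pretrained_model_path)))`. Recent PyTorch releases also include `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which does the same thing in place; check that your version has it before relying on it.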