I trained a model like this:
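```python
import torch
from torch import nn

net = resnet50(num_classes=settings.NUM_CLASSES)
# Wrap the network for multi-GPU training on all visible GPUs
net = nn.parallel.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
net = net.cuda()
...
torch.save(net.state_dict(), weights_path)
```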
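One way to check what `torch.save` actually writes in the training code is to compare the `state_dict()` keys of a network before and after wrapping it in `DataParallel`. This is a minimal sketch, assuming only that PyTorch is installed. The tiny `nn.Sequential` network is a hypothetical stand-in, not the ResNet above. It should also run on a CPU-only machine, where `DataParallel` simply holds the module.

```python
import torch
from torch import nn

# Hypothetical stand-in for the real network (not the ResNet from the question).
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

# DataParallel keeps the original network in its `module` attribute, so the
# wrapper's state_dict() names every parameter and buffer "module.<name>".
wrapped = nn.parallel.DataParallel(net)

plain_keys = list(net.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())

for plain, prefixed in zip(plain_keys, wrapped_keys):
    print(f"{plain:24} -> {prefixed}")
```

The training script calls `net.state_dict()` after wrapping, so its checkpoint keys would be expected to carry this `module.` prefix as well.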
Then I try to load the model like this:
```python
import torch
from torch import nn

model = get_network(args)
model.load_state_dict(torch.load(settings.CHECKPOINT_PATH))
model = nn.parallel.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
model.cuda()
model.eval()
```
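The mismatch can probably be reproduced without the real checkpoint. This is a rough sketch under stated assumptions: `make_net` is a hypothetical stand-in for `get_network(args)`, and an in-memory `io.BytesIO` buffer replaces the checkpoint file. It saves the way the training code does (the wrapper's `state_dict()`) and loads the way the loading code does (into the bare network). It passes `strict=False`, so `load_state_dict` returns the mismatched keys instead of raising.

```python
import io

import torch
from torch import nn


def make_net():
    # Hypothetical stand-in for get_network(args); not the real ResNet.
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))


# Save like the training script: state_dict() of the DataParallel wrapper.
buffer = io.BytesIO()  # in-memory stand-in for the checkpoint file
torch.save(nn.parallel.DataParallel(make_net()).state_dict(), buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# Load like the inference script: into the bare, unwrapped network.
# strict=False reports the mismatch instead of raising RuntimeError.
model = make_net()
result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

In this reduction, `missing_keys` should list plain names such as `0.weight`, and `unexpected_keys` should list the same names with a `module.` prefix. This is the same pattern as the missing keys in the error below.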
I get this error:
```
Traceback (most recent call last):
  File "attack.py", line 53, in <module>
    model.load_state_dict(torch.load(settings.CHECKPOINT_PATH))
  File "/home/user/.conda/envs/bm/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.0.weight", "conv1.1.weight", "conv1.1.bias", "conv1.1.running_mean", "conv1.1.running_var", "conv2_x.0.residual_function.0.weight", "conv2_x.0.residual_function.1.weight", "conv2_x.0.residual_function.1.bias", "conv2_x.0.residual_function.1.running_mean", "conv2_x.0.residu
```
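If the checkpoint keys are the missing names above with a leading `module.` (an assumption, but that is what saving a `DataParallel` wrapper produces), one possible workaround is to strip that prefix before loading into the unwrapped network. The helper `strip_module_prefix` below is hypothetical and not a PyTorch API. It uses `str.startswith` plus slicing instead of `str.removeprefix`, because the traceback shows Python 3.7 and `removeprefix` only exists from 3.9. Placeholder strings stand in for tensors so the sketch runs without PyTorch.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Keys that do not start with `prefix` are kept unchanged; values are untouched.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Placeholder values stand in for tensors; key names follow the error message.
saved = OrderedDict([
    ("module.conv1.0.weight", "w"),
    ("module.conv1.1.weight", "g"),
    ("module.conv1.1.bias", "b"),
])
print(list(strip_module_prefix(saved)))
# ['conv1.0.weight', 'conv1.1.weight', 'conv1.1.bias']
```

With the real checkpoint, this would presumably be used as `model.load_state_dict(strip_module_prefix(torch.load(settings.CHECKPOINT_PATH)))`, before wrapping the model. There are two alternatives that avoid the renaming. You can wrap the model in `DataParallel` before calling `load_state_dict`, or you can save `net.module.state_dict()` during training so the keys never get the prefix.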