Here are the start and end of the .th file that my weights are stored in:
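import torch
from pytorch_resnet_cifar10.resnet import resnet32  # assumed location of resnet32 in the cloned repo
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PATH = './pytorch_resnet_cifar10/pretrained_models/resnet32-d509ac18.th'
net = resnet32().to(device)
print(torch.load(PATH))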
Output:
{'state_dict': OrderedDict([('module.conv1.weight', tensor([[[[ 4.0355e-01, -5.4608e-01, -6.8499e-01],
[-1.0668e-01, -9.9667e-01, -1.0748e+00],
[ 1.0510e-01, -7.8233e-01, -4.1935e-01]],
[[ 3.7818e-01, 7.3026e-01, 2.7383e-01],
[ 2.3540e-01, 1.0729e+00, 6.1863e-01],
[ 1.7275e-01, 2.7025e-01, 2.5061e-01]],
...
[-0.6881, -0.7202, -0.4475, -0.3835, 0.2615, 0.6608, -0.0996, -0.0243,
-0.3856, -0.0648, -0.2151, -0.2414, -0.1699, -0.0557, -0.2486, 1.1742,
-0.6238, -0.2589, -0.4804, -0.3014, -0.4141, -0.1574, -0.3830, -0.1402,
0.3194, -0.3767, 0.0905, 0.4900, -0.2189, -0.0690, -0.2826, 0.3625,
0.8595, -0.0683, -0.1282, -0.2713, -0.3498, -0.5063, -0.0771, 0.5548,
0.2186, -0.1179, 0.4577, -0.2268, 0.1091, 0.4121, 0.7449, -0.1425,
-0.4718, -0.0647, 1.3254, 0.9775, 0.7825, -0.3761, -0.3092, -0.5736,
0.6093, -0.2387, 0.1829, -0.0349, 0.3537, -0.2790, 0.9631, -0.3342]],
device='cuda:0')), ('module.linear.bias', tensor([ 0.1763, -0.3967, 0.0932, 0.1306, 0.2050, -0.1249, 0.1557, -0.1295,
-0.0172, -0.0863], device='cuda:0'))]), 'best_prec1': 92.78000144958496}
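To see the structure without dumping every tensor, it may be easier to print only the top-level keys and the parameter names with their shapes. Below is a minimal sketch that runs against a small hypothetical checkpoint (`fake_checkpoint`, with made-up contents) built in the same layout as the dump above. For the real file, `torch.load(PATH, map_location='cpu')` would take the place of loading the temporary file:

```python
import os
import tempfile
from collections import OrderedDict

import torch

# Hypothetical stand-in for the real .th file: a dict wrapping a state_dict whose
# keys carry a "module." prefix, plus a "best_prec1" metric (mirrors the dump above).
fake_checkpoint = {
    'state_dict': OrderedDict([
        ('module.conv1.weight', torch.zeros(16, 3, 3, 3)),
        ('module.linear.bias', torch.zeros(10)),
    ]),
    'best_prec1': 92.78,
}
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.th')
torch.save(fake_checkpoint, path)

# Load onto the CPU so the file can be inspected without a GPU.
checkpoint = torch.load(path, map_location='cpu')

# Print the structure instead of every tensor value.
print(list(checkpoint.keys()))  # top-level keys of the saved dict
for name, tensor in checkpoint['state_dict'].items():
    print(name, tuple(tensor.shape))  # parameter names and shapes
```

This prints `['state_dict', 'best_prec1']` followed by the prefixed parameter names, which makes the nesting easier to see than the full tensor dump.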
As you can see, the file holds a dict with the keys ‘state_dict’ (at the beginning) and ‘best_prec1’ (at the end). Both of these appear in the error below, which is raised when I try to load the weights via net.load_state_dict(torch.load(PATH)).
Output:
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight",
...
"layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "linear.weight", "linear.bias".
Unexpected key(s) in state_dict: "state_dict", "best_prec1".
Any ideas why PyTorch won't accept this state_dict?
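Judging from the error, two layout differences seem to be involved. First, the parameters are nested under the 'state_dict' key instead of sitting at the top level, which is why "state_dict" and "best_prec1" show up as unexpected keys. Second, every parameter name carries a "module." prefix, which is typically added when a model wrapped in nn.DataParallel (or DistributedDataParallel) is saved, while the unwrapped ResNet expects names like "conv1.weight". Below is a minimal sketch of unwrapping both, assuming the checkpoint layout shown above. `TinyNet` and `load_checkpoint_weights` are hypothetical names, and `TinyNet` stands in for `resnet32()` so the snippet runs without the repo or a GPU:

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet32(); only the parameter names matter here."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.linear = nn.Linear(16, 10)


def load_checkpoint_weights(model, path, prefix='module.'):
    """Hypothetical helper: load weights from a {'state_dict': ..., 'best_prec1': ...} file."""
    checkpoint = torch.load(path, map_location='cpu')
    # 1. Unwrap: the parameters sit under the 'state_dict' key, not at the top level.
    state_dict = checkpoint['state_dict']
    # 2. Strip the 'module.' prefix that a DataParallel-wrapped model adds to every key.
    state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    # strict=True (the default) still raises if any key fails to match.
    model.load_state_dict(state_dict)
    return checkpoint.get('best_prec1')


# Build a fake checkpoint with the same layout as the .th file in the question.
source = TinyNet()
prefixed = OrderedDict(('module.' + k, v) for k, v in source.state_dict().items())
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.th')
torch.save({'state_dict': prefixed, 'best_prec1': 92.78}, path)

net = TinyNet()
best = load_checkpoint_weights(net, path)
print(best)  # 92.78
```

For the model in the question, the call would presumably be load_checkpoint_weights(net, PATH) after net = resnet32().to(device). An alternative would be to wrap net in nn.DataParallel before loading, so that its own keys gain the same "module." prefix; stripping the prefix instead keeps the plain single-device model.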