Hi community,
I trained my model on Azure using nn.DataParallel. On my local machine I can load it in plain (non-parallel) mode by removing the module. prefix from the keys of the state_dict, using the code below.
The same code raises errors on a t2.micro instance with the Amazon Deep Learning AMI (output shown as comments in the code).
# Code to load a trained model from a certain epoch saved as model-epoch.pkl
import glob
import os
from collections import OrderedDict

import torch

def get_epoch(pth):
    pth = pth.split('/')[-1]   # keep only the file name
    print('Loading dict: ', pth)
    pth = pth[:-4]             # remove .pkl
    epoch = pth.split('-')[1]  # get just the epoch
    print('Epoch to restart training: ', epoch)
    return epoch
# Output from AWS:
# Loading dict: ResNet20_2-149.pkl
# Epoch to restart training: 149
def load_weights(path, verbose=0):
    global device
    state_dict = torch.load(path, map_location=device)
    if verbose == 1: print('Current dict: ', state_dict.keys())
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove module.
        new_state_dict[name] = v
    if verbose == 1: print('New dict: ', new_state_dict.keys())
    return new_state_dict
# Output from AWS: (same as in my Local!)
# Current dict: odict_keys(['module.conv.weight', 'module.bn.weight', ....
# New dict: odict_keys(['conv.weight', 'bn.weight', ...
print('Loading trained model... ')
# Load saved models
e_epoch = get_epoch(pth)
assert os.path.exists(pth), 'Model to load not found'
ps = glob.glob(os.path.join(pth, '*.pkl'))
print('Getting ready Model : ', singleModel)
# Output from AWS (same as my local):
'''
Loading trained models...
Getting ready Single Model : ResNet(
(conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm ...
'''
s_epoch = int(get_epoch(ps[0]))
singleModel.load_state_dict(load_weights(ps[0], verbose=1))
print('[OK] Single model loaded on epoch ', s_epoch)
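A note on the prefix removal in load_weights: slicing with k[7:] drops the first seven characters of every key, whether or not the key starts with module. The sketch below is a more defensive variant that only strips the prefix when it is present. The name strip_module_prefix and the example keys are hypothetical, and plain integers stand in for tensors so the snippet runs without PyTorch.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]  # strip the prefix once, only if present
        new_state_dict[key] = value
    return new_state_dict


# Stand-in values; a real checkpoint would map parameter names to tensors.
example = OrderedDict([('module.conv.weight', 1), ('bn.weight', 2)])
print(list(strip_module_prefix(example).keys()))
# prints ['conv.weight', 'bn.weight']
```

With this variant, a checkpoint that was saved without nn.DataParallel would pass through unchanged instead of having its key names truncated.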
Why am I able to load the dictionary without errors on my local machine, while on AWS it returns:
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "layer1.3.conv1.weight", "layer1.3.conv2.weight", "layer1.3.bn1.weight", "layer1.3.bn1.bias", "layer1.3.bn1.running_mean", "layer1.3.bn1.running_var", "layer1.3.bn2.weight", "layer1.3.bn2.bias", "layer1.3.bn2.running_mean ...
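One check that might help narrow this down is to compare the parameter names the model expects against the names stored in the checkpoint. The sketch below does this with plain set differences. The helper compare_keys and the key lists are hypothetical and only for illustration; they are not taken from the actual run.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, both sorted.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not expect
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key lists for illustration only.
model_keys = ['conv.weight', 'layer1.2.conv1.weight', 'layer1.3.conv1.weight']
checkpoint_keys = ['conv.weight', 'layer1.2.conv1.weight']

missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print('Missing from checkpoint:', missing)
print('Unexpected in checkpoint:', unexpected)
```

With the real objects, the call would look like compare_keys(singleModel.state_dict().keys(), load_weights(ps[0]).keys()). If entire blocks such as layer1.3.* show up as missing, one possibility (an assumption, not something confirmed here) is that the model built on AWS has a different depth than the one that produced the checkpoint. It may also be worth printing ps[0] on both machines, since glob.glob returns results in arbitrary order and the first match can differ between systems.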
Thank you in advance!