Hi everyone,
I have seen similar topics to this one, but I haven’t been able to work out a solution from the posts on this board. Thanks in advance for your patience!
I’m training a model on an AWS virtual instance:
Deep Learning AMI (Ubuntu) Version 20.0 (ami-0d0ff0945ae093aea)
I know that it is GPU-enabled, so I can train on CUDA:
# check if CUDA is available
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Making sure it’s on CUDA before training:
if train_on_gpu:
    model.cuda()
    cudnn.benchmark = True
Saving the model:
torch.save({'arch': 'vgg19',
            'state_dict': model.state_dict(),
            'class_to_idx': model.class_to_idx},
           'classifier.pth')
I made a function that loads the model:
def load_model(checkpoint_path):
    chpt = torch.load(checkpoint_path)
    if chpt['arch'] == 'vgg19':
        model = models.vgg19(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
    model.class_to_idx = chpt['class_to_idx']
    model.cpu()
    # Create the classifier
    classifier = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(25088, 4096)),
        ('relu', nn.ReLU()),
        ('fc2', nn.Linear(4096, 102)),
        ('output', nn.LogSoftmax(dim=1))
    ]))
    # Put the classifier on the pretrained network
    model.classifier = classifier
    model.load_state_dict(chpt['state_dict'])
    return model
Loading the model:
model = load_model('vgg19.pth')
The error below is perplexing. The missing keys and the unexpected keys are the same, except that the unexpected keys have "module." in front.
How can this be?
RuntimeError: Error(s) in loading state_dict for VGG:
Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.16.weight", "features.16.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.23.weight", "features.23.bias", "features.25.weight", "features.25.bias", "features.28.weight", "features.28.bias", "features.30.weight", "features.30.bias", "features.32.weight", "features.32.bias", "features.34.weight", "features.34.bias", "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias".
Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.features.5.weight", "module.features.5.bias", "module.features.7.weight", "module.features.7.bias", "module.features.10.weight", "module.features.10.bias", "module.features.12.weight", "module.features.12.bias", "module.features.14.weight", "module.features.14.bias", "module.features.16.weight", "module.features.16.bias", "module.features.19.weight", "module.features.19.bias", "module.features.21.weight", "module.features.21.bias", "module.features.23.weight", "module.features.23.bias", "module.features.25.weight", "module.features.25.bias", "module.features.28.weight", "module.features.28.bias", "module.features.30.weight", "module.features.30.bias", "module.features.32.weight", "module.features.32.bias", "module.features.34.weight", "module.features.34.bias", "module.classifier.fc1.weight", "module.classifier.fc1.bias", "module.classifier.fc2.weight", "module.classifier.fc2.bias".
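One likely explanation, assuming the model was wrapped in torch.nn.DataParallel at some point before saving (that step is not shown above): DataParallel keeps the original network under an attribute named module, so every key in its state_dict gains a "module." prefix, while a plain VGG expects the keys without it. Below is a minimal sketch of stripping that prefix before loading. strip_module_prefix is a hypothetical helper name, and a plain dict with string values stands in for the real checkpoint so the snippet runs without a GPU or PyTorch.

```python
from collections import OrderedDict

PREFIX = "module."  # prefix that a DataParallel wrapper adds to every key


def strip_module_prefix(state_dict, prefix=PREFIX):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Strip the prefix only when it is at the start of the key;
        # keys without it are copied unchanged.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in for chpt['state_dict']; in the real checkpoint the values are tensors.
saved = OrderedDict([
    ("module.features.0.weight", "w0"),
    ("module.features.0.bias", "b0"),
    ("module.classifier.fc1.weight", "w1"),
])
cleaned = strip_module_prefix(saved)
print(list(cleaned))
```

If that assumption holds, load_model could call model.load_state_dict(strip_module_prefix(chpt['state_dict'])) instead of passing the saved dict directly. The other option would be to save model.module.state_dict() at training time, so the checkpoint never contains the prefix in the first place.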