# My code:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torchvision import models
from collections import OrderedDict
# Number of input features the replacement classifier receives for each
# supported backbone (used as in_features of the first nn.Linear in Network).
structures = {
    "vgg16": 25088,
    "densenet121": 1024,
    "alexnet": 9216,
}
def Network(structure='vgg16', dropout=0.5, hidden_layer1=120, lr=0.001):
    """Build a pretrained torchvision backbone with a fresh classifier head.

    All backbone parameters are frozen (``requires_grad = False``); the new
    ``model.classifier`` is the only part the returned optimizer updates.

    Args:
        structure: backbone name; one of 'vgg16', 'densenet121', 'alexnet'.
        dropout: dropout probability applied to the classifier input.
        hidden_layer1: width of the first hidden Linear layer.
        lr: learning rate for the Adam optimizer.

    Returns:
        A ``(model, optimizer, criterion)`` tuple: the model with its
        ``classifier`` replaced, an Adam optimizer over the classifier's
        parameters, and an ``nn.NLLLoss`` criterion (paired with the
        classifier's ``LogSoftmax`` output).

    Raises:
        ValueError: if ``structure`` is not one of the supported names.
    """
    if structure == 'vgg16':
        model = models.vgg16(pretrained=True)
    elif structure == 'densenet121':
        model = models.densenet121(pretrained=True)
    elif structure == 'alexnet':
        model = models.alexnet(pretrained=True)
    else:
        # The original only printed here and then crashed on the undefined
        # `model` below; fail immediately with a clear message instead.
        raise ValueError(
            "{} is not a valid model. Please pass vgg16, densenet121, "
            "or alexnet".format(structure)
        )

    # Freeze the pretrained feature extractor.
    for param in model.parameters():
        param.requires_grad = False

    classifier = nn.Sequential(OrderedDict([
        ('dropout', nn.Dropout(dropout)),
        ('inputs', nn.Linear(structures[structure], hidden_layer1)),
        ('relu1', nn.ReLU()),
        ('hidden_layer1', nn.Linear(hidden_layer1, 90)),
        ('relu2', nn.ReLU()),
        ('hidden_layer2', nn.Linear(90, 80)),
        ('relu3', nn.ReLU()),
        ('hidden_layer3', nn.Linear(80, 102)),
        # Log-probabilities, to pair with NLLLoss below.
        ('output', nn.LogSoftmax(dim=1)),
    ]))
    model.classifier = classifier

    criterion = nn.NLLLoss()
    # Only the new (unfrozen) classifier parameters are optimized.
    optimizer = optim.Adam(model.classifier.parameters(), lr=lr)
    return model, optimizer, criterion
# Build a DenseNet-121 model with a fresh classifier head.
model, optimizer, criterion = Network('densenet121')
# Load a checkpoint and rebuild the model.
def load_model(path):
    """Rebuild the model stored in the checkpoint at *path*.

    The checkpoint is expected to be a dict with the keys 'structure',
    'hidden_layer1', 'class_to_idx' and 'state_dict'. The architecture is
    recreated with ``Network`` and the saved weights are loaded into it.

    Args:
        path: filesystem path of the saved checkpoint (``.pth``).

    Returns:
        The rebuilt model, with ``class_to_idx`` attached.

    Raises:
        KeyError: if the checkpoint lacks one of the required keys.
        RuntimeError: if the saved weights do not match the rebuilt model
            (other than the tolerated ``num_batches_tracked`` buffers).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')

    structure = checkpoint['structure']
    hidden_layer1 = checkpoint['hidden_layer1']
    model, _, _ = Network(structure, 0.5, hidden_layer1)
    model.class_to_idx = checkpoint['class_to_idx']

    # The checkpoint contains BatchNorm `num_batches_tracked` buffers that the
    # running PyTorch's model does not define. This is presumably a checkpoint
    # saved by a newer PyTorch (the buffer was added in 0.4.1). Drop only
    # those unexpected buffers so any other mismatch still raises.
    expected = model.state_dict()
    state_dict = {
        key: value
        for key, value in checkpoint['state_dict'].items()
        if key in expected or not key.endswith('num_batches_tracked')
    }
    model.load_state_dict(state_dict)
    return model
# Load your model to this variable
model = load_model('/home/workspace/model_classifier.pth')
# If you used something other than 224x224 cropped images, set the correct
# size here.
image_size = 224

# Values you used for normalizing the images. Defaults here are for
# pretrained models from torchvision (per-channel RGB mean and std).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
# Error raised by model.load_state_dict() inside load_model():
RuntimeError: Error(s) in loading state_dict for DenseNet:
Unexpected key(s) in state_dict: "features.norm0.num_batches_tracked", "features.denseblock1.denselayer1.norm1.num_batches_tracked", "features.denseblock1.denselayer1.norm2.num_batches_tracked", "features.denseblock1.denselayer2.norm1.num_batches_tracked", "features.denseblock1.denselayer2.norm2.num_batches_tracked", "features.denseblock1.denselayer3.norm1.num_batches_tracked", "features.denseblock1.denselayer3.norm2.num_batches_tracked", "features.denseblock1.denselayer4.norm1.num_batches_tracked", "features.denseblock1.denselayer4.norm2.num_batches_tracked", "features.denseblock1.denselayer5.norm1.num_batches_tracked", "features.denseblock1.denselayer5.norm2.num_batches_tracked", "features.denseblock1.denselayer6.norm1.num_batches_tracked", "features.denseblock1.denselayer6.norm2.num_batches_tracked", "features.transition1.norm.num_batches_tracked", "features.denseblock2.denselayer1.norm1.num_batches_tracked", "features.denseblock2.denselayer1.norm2.num_batches_tracked", "features.denseblock2.denselayer2.norm1.num_batches_tracked", "features.denseblock2.denselayer2.norm2.num_batches_tracked", "features.denseblock2.denselayer3.norm1.num_batches_tracked", "features.denseblock2.denselayer3.norm2.num_batches_tracked", "features.denseblock2.denselayer4.norm1.num_batches_tracked", "features.denseblock2.denselayer4.norm2.num_batches_tracked", "features.denseblock2.denselayer5.norm1.num_batches_tracked", "features.denseblock2.denselayer5.norm2.num_batches_tracked", "features.denseblock2.denselayer6.norm1.num_batches_tracked", "features.denseblock2.denselayer6.norm2.num_batches_tracked", "features.denseblock2.denselayer7.norm1.num_batches_tracked", "features.denseblock2.denselayer7.norm2.num_batches_tracked", "features.denseblock2.denselayer8.norm1.num_batches_tracked", "features.denseblock2.denselayer8.norm2.num_batches_tracked", "features.denseblock2.denselayer9.norm1.num_batches_tracked", "features.denseblock2.denselayer9.norm2.num_batches_tracked", 
"features.denseblock2.denselayer10.norm1.num_batches_tracked", "features.denseblock2.denselayer10.norm2.num_batches_tracked", "features.denseblock2.denselayer11.norm1.num_batches_tracked", "features.denseblock2.denselayer11.norm2.num_batches_tracked", "features.denseblock2.denselayer12.norm1.num_batches_tracked", "features.denseblock2.denselayer12.norm2.num_batches_tracked", "features.transition2.norm.num_batches_tracked", "features.denseblock3.denselayer1.norm1.num_batches_tracked", "features.denseblock3.denselayer1.norm2.num_batches_tracked", "features.denseblock3.denselayer2.norm1.num_batches_tracked", "features.denseblock3.denselayer2.norm2.num_batches_tracked", "features.denseblock3.denselayer3.norm1.num_batches_tracked", "features.denseblock3.denselayer3.norm2.num_batches_tracked", "features.denseblock3.denselayer4.norm1.num_batches_tracked", "features.denseblock3.denselayer4.norm2.num_batches_tracked", "features.denseblock3.denselayer5.norm1.num_batches_tracked", "features.denseblock3.denselayer5.norm2.num_batches_tracked", "features.denseblock3.denselayer6.norm1.num_batches_tracked", "features.denseblock3.denselayer6.norm2.num_batches_tracked", "features.denseblock3.denselayer7.norm1.num_batches_tracked", "features.denseblock3.denselayer7.norm2.num_batches_tracked", "features.denseblock3.denselayer8.norm1.num_batches_tracked", "features.denseblock3.denselayer8.norm2.num_batches_tracked", "features.denseblock3.denselayer9.norm1.num_batches_tracked", "features.denseblock3.denselayer9.norm2.num_batches_tracked", "features.denseblock3.denselayer10.norm1.num_batches_tracked", "features.denseblock3.denselayer10.norm2.num_batches_tracked", "features.denseblock3.denselayer11.norm1.num_batches_tracked", "features.denseblock3.denselayer11.norm2.num_batches_tracked", "features.denseblock3.denselayer12.norm1.num_batches_tracked", "features.denseblock3.denselayer12.norm2.num_batches_tracked", "features.denseblock3.denselayer13.norm1.num_batches_tracked", 
"features.denseblock3.denselayer13.norm2.num_batches_tracked", "features.denseblock3.denselayer14.norm1.num_batches_tracked", "features.denseblock3.denselayer14.norm2.num_batches_tracked", "features.denseblock3.denselayer15.norm1.num_batches_tracked", "features.denseblock3.denselayer15.norm2.num_batches_tracked", "features.denseblock3.denselayer16.norm1.num_batches_tracked", "features.denseblock3.denselayer16.norm2.num_batches_tracked", "features.denseblock3.denselayer17.norm1.num_batches_tracked", "features.denseblock3.denselayer17.norm2.num_batches_tracked", "features.denseblock3.denselayer18.norm1.num_batches_tracked", "features.denseblock3.denselayer18.norm2.num_batches_tracked", "features.denseblock3.denselayer19.norm1.num_batches_tracked", "features.denseblock3.denselayer19.norm2.num_batches_tracked", "features.denseblock3.denselayer20.norm1.num_batches_tracked", "features.denseblock3.denselayer20.norm2.num_batches_tracked", "features.denseblock3.denselayer21.norm1.num_batches_tracked", "features.denseblock3.denselayer21.norm2.num_batches_tracked", "features.denseblock3.denselayer22.norm1.num_batches_tracked", "features.denseblock3.denselayer22.norm2.num_batches_tracked", "features.denseblock3.denselayer23.norm1.num_batches_tracked", "features.denseblock3.denselayer23.norm2.num_batches_tracked", "features.denseblock3.denselayer24.norm1.num_batches_tracked", "features.denseblock3.denselayer24.norm2.num_batches_tracked", "features.transition3.norm.num_batches_tracked", "features.denseblock4.denselayer1.norm1.num_batches_tracked", "features.denseblock4.denselayer1.norm2.num_batches_tracked", "features.denseblock4.denselayer2.norm1.num_batches_tracked", "features.denseblock4.denselayer2.norm2.num_batches_tracked", "features.denseblock4.denselayer3.norm1.num_batches_tracked", "features.denseblock4.denselayer3.norm2.num_batches_tracked", "features.denseblock4.denselayer4.norm1.num_batches_tracked", "features.denseblock4.denselayer4.norm2.num_batches_tracked", 
"features.denseblock4.denselayer5.norm1.num_batches_tracked", "features.denseblock4.denselayer5.norm2.num_batches_tracked", "features.denseblock4.denselayer6.norm1.num_batches_tracked", "features.denseblock4.denselayer6.norm2.num_batches_tracked", "features.denseblock4.denselayer7.norm1.num_batches_tracked", "features.denseblock4.denselayer7.norm2.num_batches_tracked", "features.denseblock4.denselayer8.norm1.num_batches_tracked", "features.denseblock4.denselayer8.norm2.num_batches_tracked", "features.denseblock4.denselayer9.norm1.num_batches_tracked", "features.denseblock4.denselayer9.norm2.num_batches_tracked", "features.denseblock4.denselayer10.norm1.num_batches_tracked", "features.denseblock4.denselayer10.norm2.num_batches_tracked", "features.denseblock4.denselayer11.norm1.num_batches_tracked", "features.denseblock4.denselayer11.norm2.num_batches_tracked", "features.denseblock4.denselayer12.norm1.num_batches_tracked", "features.denseblock4.denselayer12.norm2.num_batches_tracked", "features.denseblock4.denselayer13.norm1.num_batches_tracked", "features.denseblock4.denselayer13.norm2.num_batches_tracked", "features.denseblock4.denselayer14.norm1.num_batches_tracked", "features.denseblock4.denselayer14.norm2.num_batches_tracked", "features.denseblock4.denselayer15.norm1.num_batches_tracked", "features.denseblock4.denselayer15.norm2.num_batches_tracked", "features.denseblock4.denselayer16.norm1.num_batches_tracked", "features.denseblock4.denselayer16.norm2.num_batches_tracked", "features.norm5.num_batches_tracked".