This is the code I use to save the model:
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchfile
import numpy as np
class VGG16(nn.Module):
    """VGG-16 style convolutional classifier.

    The network is five blocks of 3x3 convolutions (2, 2, 3, 3 and 3 layers
    respectively), each block followed by a 2x2 max-pool, and then three
    fully connected layers. Attribute names of the layers are part of the
    ``state_dict`` keys, so they must not be renamed if existing checkpoints
    are to keep loading.
    """

    def __init__(self, num_classes: int = 2622):
        """Create all layers.

        Args:
            num_classes: size of the final logits vector. Defaults to 2622,
                the value that was hard-coded originally.
        """
        super(VGG16, self).__init__()
        # Number of convolution layers in each of the five blocks.
        self.block_size = [2, 2, 3, 3, 3]

        # Block 1
        self.conv_1_1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.conv_1_2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        # Block 2
        self.conv_2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv_2_2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        # Block 3
        self.conv_3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv_3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv_3_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        # Block 4
        self.conv_4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.conv_4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        # Block 5
        self.conv_5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)

        # Five 2x2 pools reduce a 224x224 input to 7x7, hence 512 * 7 * 7.
        self.fc6 = nn.Linear(512 * 7 * 7, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: tensor of shape (batch, 3, 224, 224); the spatial size must
                be 224x224 so that the flattened features match ``fc6``.

        Returns:
            Tensor of shape (batch, num_classes) with un-normalised logits.
        """
        x = F.relu(self.conv_1_1(x))
        x = F.relu(self.conv_1_2(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv_2_1(x))
        x = F.relu(self.conv_2_2(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv_3_1(x))
        x = F.relu(self.conv_3_2(x))
        x = F.relu(self.conv_3_3(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv_4_1(x))
        x = F.relu(self.conv_4_2(x))
        x = F.relu(self.conv_4_3(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv_5_1(x))
        x = F.relu(self.conv_5_2(x))
        x = F.relu(self.conv_5_3(x))
        x = F.max_pool2d(x, 2, 2)

        # flatten (rather than view) also accepts non-contiguous tensors.
        x = torch.flatten(x, 1)

        # Dropout is only active while the module is in training mode.
        x = F.relu(self.fc6(x))
        x = F.dropout(x, 0.5, self.training)
        x = F.relu(self.fc7(x))
        x = F.dropout(x, 0.5, self.training)
        return self.fc8(x)
if __name__ == "__main__":
    # Build a freshly initialised network and convert its parameters to
    # float64 before saving, so the checkpoint stores double-precision values.
    model = VGG16().double()
    # Only the parameters (state_dict) are saved, not the model object; to
    # load them back, first create a VGG16 instance and then call
    # instance.load_state_dict(torch.load('vggFaceModel.pth')).
    torch.save(model.state_dict(), 'vggFaceModel.pth')
This is the code I use to load the model:
import torch
from torch.autograd import Variable
import os
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from vggModel import vgg16
from vggModel3 import VGG16
# load_state_dict is an instance method. Assigning the class itself
# (`preModel = vgg16`, without parentheses) and then calling
# preModel.load_state_dict(x) passes x as `self`, which is what raises
# "load_state_dict() missing 1 required positional argument: 'state_dict'".
# The model therefore has to be instantiated before the weights are loaded.
#
# VGG16 is used because the checkpoint holds the parameters of a VGG16
# instance (conv_1_1 ... fc8).
# NOTE(review): assumes vggModel3.VGG16 is the same class that produced
# 'vggFaceModel.pth' - confirm; a different architecture would fail with
# missing/unexpected key errors instead.
# .double() mirrors the dtype the checkpoint was saved with.
preModel = VGG16().double()
preModel.load_state_dict(torch.load('vggFaceModel.pth'))
The error message is: load_state_dict() missing 1 required positional argument: 'state_dict'