Hi all,
I have applied transfer learning to a VGG11 model as shown below. My aim was to train a model that classifies CIFAR10 images, which have only 10 labels (output classes).
First, I loaded a pretrained vgg11 model. Then I froze the conv layers and updated the last layer's output size to 10. Then I trained only the linear layers for 10 epochs. After training, I saved my model parameters to a vgg.pt file.
vgg = models.vgg11(pretrained=True)
vgg.classifier[6].out_features = 10
# freeze convolution weights
for param in vgg.features.parameters():
    param.requires_grad = False
Later on, when I want to use the model weights, I do something like:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_cnn = models.vgg11(pretrained=False)
model_cnn.load_state_dict(torch.load("vgg.pt",map_location=device))
The problem is that when I print the model, I see the output layer has 1000 neurons, but I would expect it to have 10!
Can you please help me correctly load and use the fine-tuned model?
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): ReLU(inplace=True)
    (20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
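To narrow this down, I also thought about printing the shapes stored in the checkpoint itself (just a quick sketch, assuming vgg.pt was written with torch.save(vgg.state_dict(), "vgg.pt") as in my training code below):

state_dict = torch.load("vgg.pt", map_location="cpu")
# shape of the final classifier layer as stored in the checkpoint
print(state_dict["classifier.6.weight"].shape)  # I would expect torch.Size([10, 4096])
print(state_dict["classifier.6.bias"].shape)    # I would expect torch.Size([10])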
After creating the model and before loading the state dictionary, should I update the last layer size to 10, like this?
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_cnn = models.vgg11(pretrained=False)
model_cnn.classifier[6].out_features = 10
model_cnn.load_state_dict(torch.load("vgg.pt",map_location=device))
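Or, instead of only setting out_features, do I need to replace the final Linear module before calling load_state_dict? Something like this sketch (I am not sure which of the two is correct):

model_cnn = models.vgg11(pretrained=False)
# swap in a fresh 10-class layer instead of only changing the out_features attribute
model_cnn.classifier[6] = nn.Linear(4096, 10)
model_cnn.load_state_dict(torch.load("vgg.pt", map_location=device))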
Note:
The complete code to train the model is:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import time
import matplotlib.pyplot as plt
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_data = datasets.CIFAR10(root='../../CIFARDATA', train=True, download=False, transform=transform)
test_data = datasets.CIFAR10(root='../../CIFARDATA', train=False, download=False, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(2)
vgg = models.vgg11(pretrained=True)
#print(vgg)
vgg.classifier[6].out_features = 10
# freeze convolution weights
for param in vgg.features.parameters():
    param.requires_grad = False
# move the model to the selected device before creating the optimizer
vgg = vgg.to(device)
optimizer = optim.SGD(vgg.classifier.parameters(), lr=0.001, momentum=0.9)
# loss function
criterion = nn.CrossEntropyLoss()
# validation function
def validate(model, test_dataloader):
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for i, data in enumerate(test_dataloader):
            data, target = data[0].to(device), data[1].to(device)
            output = model(data)
            loss = criterion(output, target)
            val_running_loss += loss.item()
            _, preds = torch.max(output.data, 1)
            val_running_correct += (preds == target).sum().item()
    val_loss = val_running_loss / len(test_dataloader.dataset)
    val_accuracy = 100. * val_running_correct / len(test_dataloader.dataset)
    return val_loss, val_accuracy
def fit(model, train_dataloader):
    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    for i, data in enumerate(train_dataloader):
        data, target = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        train_running_loss += loss.item()
        _, preds = torch.max(output.data, 1)
        train_running_correct += (preds == target).sum().item()
        loss.backward()
        optimizer.step()
    train_loss = train_running_loss / len(train_dataloader.dataset)
    train_accuracy = 100. * train_running_correct / len(train_dataloader.dataset)
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}')
    return train_loss, train_accuracy
print("BEGIN TRAINING")
train_loss , train_accuracy = [], []
val_loss , val_accuracy = [], []
start = time.time()
for epoch in range(10):
    train_epoch_loss, train_epoch_accuracy = fit(vgg, train_loader)
    val_epoch_loss, val_epoch_accuracy = validate(vgg, test_loader)
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    print("at epoch ", epoch)
end = time.time()
print((end-start)/60, 'minutes')
torch.save(vgg.state_dict(), "vgg.pt")
print("done")
model_cnn = models.vgg11(pretrained=False)
model_cnn.load_state_dict(torch.load("vgg.pt"))
model_cnn = model_cnn.to(device)
model_cnn.eval()
total_loss, total_err = 0., 0.
for X, y in test_loader:
    X, y = X.to(device), y.to(device)
    yp = model_cnn(X)
    total_err += (yp.max(dim=1)[1] != y).sum().item()
print("validation error")
print(total_err / len(test_loader.dataset))
plt.figure(figsize=(10, 7))
plt.plot(train_accuracy, color='green', label='train accuracy')
plt.plot(val_accuracy, color='blue', label='validation accuracy')
plt.legend()
plt.savefig('accuracy.png')
plt.show()