Greetings.
I started using the pre-trained models in Torchvision such as Resnet18. As a first test, I wanted to adjust the pretrained model to CIFAR10. I changed the dimensions of the last layer, and froze all other layers. I decrease my learning rate every 10 epochs and train for 50 epochs. Yet, my accuracies seem to be much worse than what they should be. I only reach around 40% on the training data, even though I have read with only 5 epochs of retraining, one can easily reach 70% accuracies.
I would appreciate if someone could take a glance at my code to see whether everything looks correct.
#%%
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
import time
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
# method to get accuracies and loss to a given model on a dataloader
def eval_model(model, dataloader, criterion):
accu = 0.0
loss = 0.0
model.eval()
with torch.no_grad():
for (batchidx, (features, targets)) in enumerate(dataloader):
output = model(features)
loss += criterion(output, targets, reduction="sum")
_, predicted = torch.max(output,1)
accu += torch.sum(predicted == targets)
if dataloader.drop_last == True:
accu /= np.float(dataloader.batch_size * len(dataloader))
loss /= np.float(dataloader.batch_size * len(dataloader))
else:
accu /= len(dataloader.dataset)
loss /= len(dataloader.dataset)
return (loss, accu)
#%% Training parameters
epochs = 50
meas_freq = 5 # take measurements every 5 epochs
bs = 128
eta = 0.1
mom = 0.9
seed = 0
torch.manual_seed(seed)
#%% Load CIFAR 10
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
transfos = transforms.Compose([transforms.ToTensor(), normalize])
train = datasets.CIFAR10(".", train = True, download = True, transform = transfos)
test = datasets.CIFAR10(".", train = False, download = True, transform = transfos)
batches_train = torch.utils.data.DataLoader(train, batch_size=bs, shuffle=True, drop_last=True)
batches_test = torch.utils.data.DataLoader(test, batch_size=bs, shuffle=True, drop_last=True)
#%% Load ResNet and change last layer
model = models.resnet18(pretrained=True)
for params in model.parameters():
params.requires_grad = False
model.fc = torch.nn.Linear(512,10)
#%% Retrain last layer
print("starting training...")
start_time = time.time()
loss_train = np.zeros(epochs+1) # store loss here
accu_train = np.zeros(epochs+1) # store accuracies
loss_test = np.zeros(epochs//meas_freq +1)
accu_test = np.zeros(epochs//meas_freq +1)
(loss_train[0], accu_train[0]) = eval_model(model, batches_train, F.cross_entropy)
(loss_test[0], accu_test[0]) = eval_model(model, batches_test, F.cross_entropy)
t = 1 # counter
optimizer = optim.SGD(model.parameters(), lr=eta, weight_decay=0.0005, momentum=mom)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1, last_epoch=-1, verbose=False)
for epoch in range(1,epochs+1):
model.train()
# iterate over batches
for (batchidx, (features, targets)) in enumerate(batches_train):
output = model(features)
optimizer.zero_grad()
loss = F.cross_entropy(output, targets, reduction="sum")
loss.backward()
optimizer.step()
loss_train[epoch] += loss
accu_train[epoch] += torch.sum(torch.max(output,1)[1] == targets)
if epoch % meas_freq==0: # obtain test measurements
print("epoch: {}".format(epoch))
(loss_test[t], accu_test[t]) = eval_model(model, batches_test, F.cross_entropy)
t += 1
scheduler.step()
loss_train /= len(batches_train) * bs ## average loss.
accu_train /= len(batches_train) * bs ## average accuracy.
end_time = time.time()
print("Training took {} seconds, i.e {} minutes, with {} seconds per epoch!".format(end_time-start_time, (end_time-start_time)/60, (end_time-start_time)/epochs))
#%% Plot stuff
fig, axs = plt.subplots(1,2) # prepare plotting
axs[0].plot(np.arange(0,epochs+1), loss_train, label="Train") # plot losses
axs[0].legend()
axs[0].set_ylabel("Loss")
axs[0].set_xlabel("Epochs")
axs[1].plot(np.arange(0,epochs+1), accu_train, label="Train")
axs[1].legend()
axs[1].set_ylabel("Accuracy")
axs[1].set_xlabel("Epochs")
fig.suptitle("Train")
plt.show()