import torch
import torchvision
import torchvision.transforms as transforms
import os
# CIFAR-10 per-channel normalization constants, shared by the train and test
# pipelines so the two cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# NOTE(review): the loaders use batch size 4, while a separate
# `batch_size = 128` declared later in this script is never passed to them --
# confirm which value was intended.
loader_batch_size = 4

# NOTE(review): with num_workers > 0 on spawn-based platforms (Windows, macOS)
# DataLoader workers re-import this module; a module-level training script
# like this one normally needs an `if __name__ == "__main__":` guard there.

# Training pipeline: random 32x32 crop from a 4-pixel padded image plus a
# random horizontal flip, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=loader_batch_size, shuffle=True, num_workers=2
)

# Evaluation pipeline: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
# Sample order does not affect test accuracy/loss, so evaluation is not
# shuffled; this keeps the evaluation pass reproducible.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=loader_batch_size, shuffle=False, num_workers=2
)

# Human-readable class names, indexed by CIFAR-10 label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import torchvision.models as models
# ImageNet-pretrained VGG16 with batch normalization, used as the backbone.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.VGG16_BN_Weights.IMAGENET1K_V1`; kept here so older
# torchvision versions still work.
vggnet = models.vgg16_bn(pretrained=True)
import matplotlib.pyplot as plt
# Adapt the ImageNet VGG16-BN to 32x32 CIFAR-10 images.
#
# The convolutional `features` stack is kept as-is. For 32x32 inputs its five
# 2x2 max-pools leave a 512x1x1 map, i.e. 512 features after flattening.
# (The previous version re-appended an AvgPool2d(kernel_size=1, stride=1),
# which is an identity op, and has no parameters -- state_dict keys are
# unchanged by dropping it.)

# Replace torchvision's adaptive average pool with an identity: pooling a
# 1x1 map up to 7x7 would yield 25088 features instead of 512.
vggnet.avgpool = nn.Sequential()

# Replace the three-layer ImageNet classifier with a single 512 -> 10 linear
# layer (one output per CIFAR-10 class). It stays at index 0, so the
# checkpoint keys remain `classifier.0.weight` / `classifier.0.bias`.
vggnet.classifier = nn.Sequential(nn.Linear(512, 10))
class CifarVggnet(nn.Module):
    """Thin ``nn.Module`` wrapper around a (CIFAR-adapted) VGG network.

    The backbone is stored as ``self.vggnet``, so every key in this module's
    ``state_dict()`` is prefixed with ``"vggnet."``. Load saved weights into a
    ``CifarVggnet`` instance (not the bare backbone), or strip that prefix
    first.
    """

    def __init__(self, vggnet):
        # Must be the dunder ``__init__``: a plain ``init`` method is never
        # called by the constructor, so ``CifarVggnet(net)`` would raise.
        super().__init__()
        self.vggnet = vggnet

    def forward(self, x):
        """Return the wrapped network's output for input batch ``x``."""
        return self.vggnet(x)
# Wrap the adapted backbone and print its layer summary.
model = CifarVggnet(vggnet)
print(model)

# File the best model's state_dict is written to during training.
PATH = './checkpoint/vggCifar10TEST.pth'
# NOTE(review): this value is not passed to the DataLoaders (they are built
# with batch_size=4); confirm which batch size was intended.
batch_size = 128
num_epochs = 1

# Move the parameters to `device` before the optimizer captures them.
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
torch.cuda.empty_cache()

# Bookkeeping used by the training loop.
globaliter = -1
best_accuracy = 0
trainLOSS = []  # one training-loss value appended per epoch
testACC = []    # one test accuracy (%) appended per epoch
testLOSS = []   # one test-loss value appended per epoch
avg_train = 0
avg_test_loss = 0
def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, log_every=200):
    """Run one optimization pass over ``loader``.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss module. Assumed to use ``reduction='mean'`` (the
            ``nn.CrossEntropyLoss`` default), so each batch loss is scaled
            back up by its batch size before averaging.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.
        epoch: zero-based epoch index, used only in progress messages.
        log_every: print a progress line every this many batches.

    Returns:
        ``(accuracy_percent, mean_loss_per_sample)`` over the whole epoch.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    n_correct = 0
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch = labels.size(0)
        n_seen += batch
        n_correct += (torch.argmax(outputs, 1) == labels).sum().item()
        # `loss` is a batch *mean*; weight it by the batch size so that
        # dividing by n_seen gives a true per-sample average (dividing a sum
        # of means by the sample count under-reports by ~batch size).
        loss_sum += loss.item() * batch
        if i % log_every == log_every - 1:
            print('Epoch:[%d, %5d] ' % (epoch + 1, i + 1))
    return 100 * n_correct / n_seen, loss_sum / n_seen


def _evaluate(model, loader, criterion, device):
    """Return ``(accuracy_percent, mean_loss_per_sample)`` of ``model`` on ``loader``.

    Runs in eval mode without gradient tracking. Like ``_train_one_epoch``,
    it assumes ``criterion`` returns a per-batch mean.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch
            n_seen += batch
            n_correct += (torch.argmax(outputs, 1) == labels).sum().item()
    return 100 * n_correct / n_seen, loss_sum / n_seen


for epoch in range(num_epochs):
    # Counts epochs; the test pass no longer resets it.
    globaliter += 1

    train_acc, avg_train = _train_one_epoch(
        model, trainloader, criterion, optimizer, device, epoch
    )
    trainLOSS.append(avg_train)
    print('Train Accuracy:%.3f' % train_acc)
    print('Train Average Loss:%.4f' % avg_train)

    test_acc, avg_test_loss = _evaluate(model, testloader, criterion, device)
    if test_acc > best_accuracy:
        best_accuracy = test_acc
        print('Saving...')
        # Create exactly the directory PATH points into, rather than a
        # separately hard-coded folder name.
        os.makedirs(os.path.dirname(PATH) or '.', exist_ok=True)
        # `model` is a CifarVggnet wrapper, so the saved keys carry the
        # "vggnet." prefix; reload them into a CifarVggnet instance.
        torch.save(model.state_dict(), PATH)
    print('Accuracy of the network on test images:%.3f %%' % test_acc)
    testACC.append(test_acc)
    testLOSS.append(avg_test_loss)

    # Redraw both figures from scratch each epoch so lines don't stack up.
    # Markers make the curves visible even when only one epoch has run.
    epochs_axis = range(1, len(trainLOSS) + 1)  # 1-based, matching the log output

    fig1 = plt.figure(1)
    fig1.clf()
    plt.plot(epochs_axis, trainLOSS, 'r-o', label='train loss')
    plt.plot(epochs_axis, testLOSS, 'g-o', label='test loss')
    plt.legend(loc='upper left')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')

    fig2 = plt.figure(2)
    fig2.clf()
    plt.plot(epochs_axis, testACC, 'g-o', label='test')
    plt.legend(loc='upper left')
    plt.xlabel('Epochs')
    plt.ylabel('Testing Accuracy')
print('Best Accuracy of the network on test images:%.3f %%' % best_accuracy)
# Figures 1 and 2 hold the loss and accuracy plots drawn in the epoch loop.
# plt.figure(n) returns the existing figure -- or a fresh empty one when
# num_epochs == 0 -- so this no longer raises NameError on fig1/fig2.
plt.figure(1).savefig('trainloss_vs_epoch1.png')
plt.figure(2).savefig('testacc_vs_epoch1.png')
This is the code I use to train and save the model. For testing purposes, I set num_epochs=1.