# Full CIFAR-10 transfer-learning pipeline: download data, build loaders, train, fine-tune ResNet-18.
# The ten CIFAR-10 class names, used for both directory creation and the split.
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# Download and unpack the fast.ai CIFAR-10 image archive into ./data.
data = download_url("https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz", ".")
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
    tar.extractall(path='./data')

# Carve a validation split out of the training set: 500 random images per class.
# Guarded by the existence check so a re-run does not move another 500 images.
if not os.path.exists("/content/data/cifar10/validate"):
    for cls in CLASSES:
        os.makedirs(os.path.join("/content/data/cifar10/validate", cls))
    for cls in CLASSES:
        pattern = "/content/data/cifar10/train/{}/*.png".format(cls)
        for img in sample(glob.glob(pattern), 500):
            shutil.move(img, os.path.join("/content/data/cifar10/validate", cls))
from torch.utils.data.dataset import random_split

# Normalization stats (generic 0.5/0.25 values, not the exact CIFAR-10 channel
# statistics — presumably fine for fine-tuning; confirm if accuracy matters).
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.25, 0.25, 0.25])

# Training transform: random augmentation (crop + horizontal flip), then normalize.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
# BUG FIX: validation must be deterministic — the original applied the random
# training augmentations to the validation set, making val metrics noisy.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_ds = ImageFolder("/content/data/cifar10/train", transform)
val_ds = ImageFolder("/content/data/cifar10/validate", val_transform)
# pin_memory speeds host->GPU copies; shuffle only the training stream.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_model(model, train_dl, val_dl, criterion, optimizer, scheduler, num_epochs=25):
    """Train `model`, tracking and restoring the best-validation-accuracy weights.

    Args:
        model: network to train (tensors are moved to the model's own device).
        train_dl: DataLoader of (images, labels) training batches.
        val_dl: DataLoader of (images, labels) validation batches.
        criterion: loss function, e.g. nn.CrossEntropyLoss().
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped once per epoch after the training pass.
        num_epochs: number of full passes over train_dl.

    Returns:
        The model with the weights of the epoch that had the best validation accuracy.
    """
    since = time.time()
    # Derive the device from the model instead of relying on a hidden global.
    device = next(model.parameters()).device
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # ---- training pass ----
        model.train()
        # Stats reset every epoch (the original accumulated across epochs).
        running_loss = 0.0
        running_corrects = 0
        for images, labels in train_dl:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_corrects += torch.sum(preds == labels).item()
        # Step the LR schedule once per epoch, after all optimizer steps.
        scheduler.step()
        epoch_loss_training = running_loss / len(train_dl)
        # Accuracy divides by sample count, not batch count (original bug).
        epoch_acc_training = running_corrects / len(train_dl.dataset)

        # ---- validation pass (no gradients) ----
        model.eval()
        running_loss = 0.0
        running_corrects = 0
        with torch.no_grad():
            for images, labels in val_dl:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels).item()
        epoch_loss_validation = running_loss / len(val_dl)
        epoch_acc_validation = running_corrects / len(val_dl.dataset)

        print("epoch {}, epoch training loss {}, epoch training acc {}, "
              "epoch_loss_validation {}, epoch_acc_validation {}".format(
                  epoch, epoch_loss_training, epoch_acc_training,
                  epoch_loss_validation, epoch_acc_validation))

        # Snapshot weights whenever validation accuracy improves
        # (original compared a cumulative correct-count against an accuracy).
        if epoch_acc_validation > best_acc:
            best_acc = epoch_acc_validation
            best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # Restore the best-epoch weights before returning.
    model.load_state_dict(best_model_wts)
    return model
# Transfer learning: fine-tune only the classification head of an
# ImageNet-pretrained ResNet-18 on the 10 CIFAR-10 classes.
model_conv = torchvision.models.resnet18(pretrained=True)
model_conv.requires_grad_(False)  # freeze the convolutional backbone
# Swap the 1000-way ImageNet head for a fresh 10-way head; the new
# Linear's parameters require grad by default, so only they will train.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 10)
model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()
# Optimize the new head only; decay the LR by 10x every 7 epochs.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_conv = train_model(model_conv, train_dl, val_dl, criterion,
                         optimizer_conv, exp_lr_scheduler, num_epochs=25)