Thanks for your reply. Here is the code:
The model
# Load ResNet-18 with pretrained weights; it is used as a fixed feature
# extractor with a new classification head on top.
resnet_model = models.resnet18(pretrained=True)

# Freeze all pretrained weights so no gradients are computed for them.
for param in resnet_model.parameters():
    param.requires_grad_(False)

# Replace the final fully connected layer with one sized for our classes.
# The new layer is created after the freeze above, so its parameters keep
# the default requires_grad=True and are the only trainable ones.
num_ftrs = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features=num_ftrs, out_features=len(classes))
resnet_model = resnet_model.to(device)

criterion = nn.CrossEntropyLoss()

# Only the parameters of the new head are handed to the optimizer.
optimizer = torch.optim.Adam(
    resnet_model.fc.parameters(),
    lr=1e-3,
    weight_decay=0,
)
Data
# torchvision's pretrained models expect inputs normalized with the ImageNet
# per-channel mean and std. The backbone is frozen, so it cannot adapt to
# un-normalized inputs; both splits must apply the same normalization.
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])

data_transforms = {
    # Training: random crop + flip for augmentation.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_normalize,
    ]),
    # Validation: deterministic resize + center crop, no augmentation.
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        imagenet_normalize,
    ]),
}

# One ImageFolder per split, read from <data_dir>/train and <data_dir>/val.
images = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}

# Shuffle only the training split; the order of validation samples does not
# affect the aggregated metrics, so it is kept deterministic.
dataloaders = {
    x: torch.utils.data.DataLoader(images[x], batch_size=128,
                                   shuffle=(x == 'train'), num_workers=4)
    for x in ['train', 'val']
}
Training
def train():
    """Run one training epoch over ``dataloaders['train']``.

    Relies on the module-level ``resnet_model``, ``criterion``,
    ``optimizer``, ``device`` and ``dataloaders``.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` where ``epoch_loss`` is the mean
        loss per sample and ``epoch_acc`` is the fraction of correctly
        classified samples, both over the samples actually iterated.

    Raises:
        ZeroDivisionError: if the training loader yields no samples.
    """
    resnet_model.train(True)
    # NOTE(review): train(True) switches the whole network to training mode,
    # including the BatchNorm layers of the frozen backbone, so their running
    # statistics keep being updated. Confirm this is intended; otherwise keep
    # the backbone in eval mode and only put the new head in training mode.

    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    # Scope grad mode to this function instead of flipping the global switch,
    # so the caller's grad mode is restored on exit.
    with torch.set_grad_enabled(True):
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_size = inputs.size(0)

            optimizer.zero_grad()
            outputs = resnet_model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns the batch mean; weight it by the batch
            # size so a smaller final batch is not over-represented.
            running_loss += loss.item() * batch_size
            running_corrects += torch.sum(preds == labels).item()
            num_samples += batch_size

    epoch_loss = running_loss / num_samples
    epochs_acc = running_corrects / num_samples
    return epoch_loss, epochs_acc
Validation
def evaluate():
    """Evaluate the model on ``dataloaders['val']`` without updating it.

    Relies on the module-level ``resnet_model``, ``criterion``, ``device``
    and ``dataloaders``.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` where ``epoch_loss`` is the mean
        loss per sample and ``epoch_acc`` is the fraction of correctly
        classified samples, both over the samples actually iterated.

    Raises:
        ZeroDivisionError: if the validation loader yields no samples.
    """
    resnet_model.train(False)

    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    # No backward pass happens here, so disable autograd to avoid building
    # graphs that are never used; the previous grad mode is restored on exit.
    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_size = inputs.size(0)

            outputs = resnet_model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # The criterion returns the batch mean; weight it by the batch
            # size so a smaller final batch is not over-represented.
            running_loss += loss.item() * batch_size
            running_corrects += torch.sum(preds == labels).item()
            num_samples += batch_size

    epoch_loss = running_loss / num_samples
    epochs_acc = running_corrects / num_samples
    return epoch_loss, epochs_acc
Main Loop
# Train/validate epoch after epoch. The model is checkpointed whenever the
# validation loss beats the best one so far, and the loop stops once
# bad_epochs (reset on every improvement) reaches 10.
while True:
    epoch_start_time = time.time()

    # Training pass.
    train_loss, train_acc = train()
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    print('-' * 73)
    print('| End of epoch: {:3d} | Time: {:6.2f}s | Train loss: {:.2f} | Train acc: {:.2f}|'.format(
        epoch, (time.time() - epoch_start_time), train_loss, train_acc))

    # Validation pass. The reported time is measured from the start of the
    # epoch, so it includes the training pass as well.
    val_loss, val_acc = evaluate()
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    print('-' * 73)
    print('| End of epoch: {:3d} | Time: {:6.2f}s | Valid loss: {:.2f} | Valid acc: {:.2f}|'.format(
        epoch, (time.time() - epoch_start_time), val_loss, val_acc))

    # Early stopping / checkpointing.
    if not val_loss < best_val_loss:
        bad_epochs += 1
        if bad_epochs == 10:
            break
    else:
        with open('resnet_model.pt', 'wb') as f:
            torch.save(resnet_model, f)
        best_val_loss = val_loss
        bad_epochs = 0

    epoch += 1