I get poor results when I train a CNN model.
I am using GTSRB datasets GTSRB_Final_Training_Images.zip and GTSRB_Final_Test_Images.zip.
My code:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.shape[0], -1)
model_cnn = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
nn.Conv2d(32, 32, 3, padding=1, stride=2), nn.ReLU(),
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1, stride=2), nn.ReLU(),
nn.Flatten(),
nn.Linear(4096, 100), nn.ReLU(),
nn.Linear(100, 44)).to(device)
data_transforms = transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor(),
#transforms.Normalize((0.5,), (1.0,))
])
train_data_path = "GTSRB/Train"
test_data_path = "GTSRB/Test"
train_data = datasets.ImageFolder(root = train_data_path,transform = data_transforms)
test_data = datasets.ImageFolder(root = test_data_path, transform = data_transforms)
train_loader = DataLoader(train_data, batch_size = 100, shuffle=True)
test_loader = DataLoader(test_data, batch_size = 100, shuffle=False)
def epoch(loader, model, opt=None):
total_loss, total_err = 0.,0.
for X,y in loader:
X,y = X.to(device), y.to(device)
yp = model(X)
loss = nn.CrossEntropyLoss()(yp,y)
if opt:
opt.zero_grad()
loss.backward()
opt.step()
total_err += (yp.max(dim=1)[1] != y).sum().item()
total_loss += loss.item() * X.shape[0]
return total_err / len(loader.dataset), total_loss / len(loader.dataset)
opt = optim.SGD(model_cnn.parameters(), lr=1e-1)
for t in range(20):
train_err, train_loss = epoch(train_loader, model_cnn, opt)
test_err, test_loss = epoch(test_loader, model_cnn)
if t == 4:
for param_group in opt.param_groups:
param_group["lr"] = 1e-2
print(*("{:.6f}".format(i) for i in (train_err, train_loss, test_err, test_loss)), sep="\t")
The output:
0.518248 0.851532 1.000000 11.033573
0.514448 0.821411 0.999921 12.305183
0.515213 0.801951 1.000000 12.332444
0.519358 0.787077 1.000000 12.669552
0.511706 0.775102 0.997783 12.443266
0.509730 0.732133 0.999842 13.915797
0.508263 0.724747 0.999842 14.257156
0.510163 0.722937 1.000000 14.508054
0.510980 0.722053 0.999921 14.173411
0.511617 0.721190 0.999921 14.703455
0.510368 0.720383 1.000000 14.771137
0.509437 0.719646 0.999683 14.247469
0.512255 0.719317 1.000000 14.755536
0.513505 0.718741 0.999525 14.515192
0.511222 0.718012 1.000000 14.858491
0.510100 0.717544 0.999604 14.981595
0.509335 0.717167 0.999842 14.997303
0.509704 0.716572 0.999842 14.829027
0.511821 0.716451 0.999921 15.136488
0.512574 0.716203 0.999921 15.007975
What can I change to improve the result?
/Oualid