I’m new to PyTorch and I’m trying to build a CNN model for the CIFAR-10 dataset. I’ve been working on it for a while, but I’m facing some challenges and would really appreciate some help from the community.
model:
class CNN_model(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
nn.Flatten(),
nn.Linear(128*8*8, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 10))
def forward(self, x):
return self.network(x)
model = CNN_model()
training code:
def train(model, loss_fn, optimizer, train_dl, num_epoch, learning_rate):
optim = optimizer(model.parameters(), lr=learning_rate)
trian_loss = 0
for epoch in range(num_epoch):
model.train()
for i, (images, labels) in enumerate(train_dl):
images, labels = images.to(device), labels.to(device)
optim.zero_grad()
logits = model(images)
loss = loss_fn(logits, labels)
loss.backward() # delta_loss/ delta_model_parameters
optim.step()
trian_loss += loss.item()
print(f"Epoch: {epoch+1} | Training loss: {trian_loss:.2f}")
print('Training Finished!')
model.to(device)
learning_rate = 0.001
loss_func = nn.CrossEntropyLoss()
optim = torch.optim.Adam
num_epochs = 10
train(model=model,
loss_fn=loss_func,
optimizer=optim,
train_dl=train_loaders,
num_epoch=num_epochs,
learning_rate=learning_rate)
# Testing loop
num_correct = 0
num_smaples = 0
with torch.inference_mode():
model.eval
for i, (data, labels) in enumerate(test_loaders):
data, labels = data.to(device), labels.to(device)
y_pred = model(data)
loss = loss_func(y_pred, labels)
_, prediction = y_pred.max(1)
num_correct += (prediction == labels).sum().item()
num_smaples += prediction.size(0)
print(
f"{num_correct}/{num_smaples} correctly classified | accuracy : {float(num_correct)/float(num_smaples)*100:.2f}%"
)
output: why training loss is so high??
Files already downloaded and verified
Files already downloaded and verified
Epoch: 1 | Training loss: 1096.02
Epoch: 2 | Training loss: 1789.25
Epoch: 3 | Training loss: 2293.96
Epoch: 4 | Training loss: 2659.38
Epoch: 5 | Training loss: 2897.83
Epoch: 6 | Training loss: 3050.23
Epoch: 7 | Training loss: 3153.03
Epoch: 8 | Training loss: 3231.70
Epoch: 9 | Training loss: 3300.38
Epoch: 10 | Training loss: 3358.28
Training Finished!
7519/10000 correctly classified | accuracy : 75.19%