I am building a model to classify images from a subset of the EMNIST letters dataset, with 26 classes. I have tried adjusting the learning rate and the size of the dataset, but CrossEntropyLoss still plateaus at around 3.25, and the predicted probability for every class is around 1/26, which means the model is essentially guessing (a uniform prediction over 26 classes gives a loss of ln 26 ≈ 3.258). I have also tried adding another convolution-pooling block, but the performance is the same.
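As a quick sanity check on the plateau value: if the model assigns probability 1/26 to every class, the cross-entropy for any target is -ln(1/26) = ln 26 ≈ 3.258, which matches the observed loss. A minimal sketch in plain Python (no PyTorch needed); the helper name `cross_entropy` is hypothetical and only mirrors what `CrossEntropyLoss` computes for a single sample:

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    # log-sum-exp with the max subtracted for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

num_classes = 26
uniform_logits = [0.0] * num_classes  # equal logits -> softmax gives 1/26 per class
loss = cross_entropy(uniform_logits, target=3)

print(round(loss, 4))                   # 3.2581
print(round(math.log(num_classes), 4))  # 3.2581
```

Any set of equal logits gives the same result, so a loss stuck near 3.258 means the outputs carry no class information.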
Data transform:
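import torch
from torchvision.transforms import v2

image_transform = v2.Compose([
    v2.Resize(28),
    v2.ToImage(),                           # PIL image -> tensor image
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.RandomHorizontalFlip(p=1),           # p=1: always flip horizontally
    v2.RandomRotation(degrees=(90, 90)),    # always rotate exactly 90 degrees counterclockwise
    v2.Normalize((0.1736,), (0.3248,)),     # mean and std used for normalization
])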
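With `p=1` and `degrees=(90, 90)`, both transforms are deterministic, and a horizontal flip followed by a 90° counterclockwise rotation (torchvision rotates counterclockwise for positive angles) is equivalent to transposing the image. The sketch below checks this on a small nested-list "image" in plain Python; `hflip`, `rot90_ccw`, and `transpose` are hypothetical helpers, not torchvision functions:

```python
def hflip(img):
    """Mirror each row left-to-right (like RandomHorizontalFlip with p=1)."""
    return [row[::-1] for row in img]

def rot90_ccw(img):
    """Rotate 90 degrees counterclockwise: new[i][j] = img[j][W - 1 - i]."""
    return [list(row) for row in zip(*img)][::-1]

def transpose(img):
    """Swap rows and columns: new[i][j] = img[j][i]."""
    return [list(row) for row in zip(*img)]

img = [[1, 2, 3],
       [4, 5, 6]]  # a 2x3 "image"

print(rot90_ccw(hflip(img)))  # [[1, 4], [2, 5], [3, 6]]
print(transpose(img))         # [[1, 4], [2, 5], [3, 6]]
```

On square 28×28 inputs the same identity holds, so the pair of transforms could in principle be replaced by a single transpose of the height and width axes.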
This is my model:
# (defined in __init__ of the model's nn.Module subclass)
# Convolution-pooling block 1:
self.conv1 = nn.Conv2d(1, 256, 5, padding='same')
self.le_relu1 = nn.LeakyReLU()
self.mpool1 = nn.MaxPool2d(2)
self.bn1 = nn.BatchNorm2d(256)
# Convolution-pooling block 2:
self.conv2 = nn.Conv2d(256, 128, 3, padding='same')
self.le_relu2 = nn.LeakyReLU()
self.mpool2 = nn.MaxPool2d(2)
# Output block:
self.flatten = nn.Flatten()
conv2_channels = 128       # out_channels of self.conv2
mpool2_HW = 28 // 2 // 2   # 'same' padding keeps 28x28; each MaxPool2d(2) halves it -> 7
flattened_dim = mpool2_HW**2 * conv2_channels  # flattened height * width * channels = 6272
self.fc1 = nn.Linear(in_features=flattened_dim, out_features=128)
self.le_relu3 = nn.LeakyReLU()
self.dropout1 = nn.Dropout()
self.fco = nn.Linear(in_features=128, out_features=26)  # 26 categories
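The input size of `fc1` follows from the layer settings: `padding='same'` (with stride 1) keeps height and width unchanged, and each `MaxPool2d(2)` halves them, rounding down. Assuming 28×28 inputs, as `v2.Resize(28)` produces for the square EMNIST images, that gives 28 → 14 → 7, so the flattened size is 7 · 7 · 128 = 6272. A plain-Python sketch of that arithmetic (the helper names are hypothetical):

```python
def conv_same(hw):
    """padding='same' with stride 1 keeps the spatial size unchanged."""
    return hw

def max_pool(hw, kernel=2):
    """MaxPool2d(kernel) with default stride=kernel and no padding: floor division."""
    return hw // kernel

def flattened_features(input_hw=28, conv2_out_channels=128):
    hw = max_pool(conv_same(input_hw))  # block 1: conv1 -> mpool1
    hw = max_pool(conv_same(hw))        # block 2: conv2 -> mpool2
    return hw * hw * conv2_out_channels

print(flattened_features())  # 6272
```

If the extra convolution-pooling block also ends in a `MaxPool2d(2)`, the spatial size drops to 7 // 2 = 3, and `in_features` of `fc1` has to change with it.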
This is my training loop:
model.train()  # training mode
running_loss = 0.0
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)  # have also tried lr=1e-4, as well as AdamW
for i, (X, y) in enumerate(train_dataloader):
    optimizer.zero_grad()
    X = X.to(device=device)
    y = y.to(device=device)   # class indices, expected in [0, 25] for 26 outputs
    preds = model(X)          # raw logits, shape (batch_size, 26)
    loss = loss_fn(preds, y)  # CrossEntropyLoss applies log-softmax internally
    current_loss = loss.item()
    running_loss += current_loss
    loss.backward()
    optimizer.step()  # update params
    # accuracy:
    pred_probabilities = torch.softmax(preds, dim=1)        # per-class probabilities, for inspection
    most_probable = torch.argmax(pred_probabilities, dim=1)  # predicted class = highest probability per sample
    num_correct = (most_probable == y).sum().item()
    accuracy = num_correct / y.size(0)
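Because softmax preserves the ordering of the logits (every logit goes through the same increasing `exp` and is divided by the same sum), taking the argmax of `pred_probabilities` gives the same predicted class as taking the argmax of `preds` directly; the softmax is only needed to look at the probabilities. A plain-Python sketch of the batch accuracy calculation, using made-up logits and labels (all values and helper names are illustrative only):

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    return max(range(len(values)), key=lambda k: values[k])

# Hypothetical batch of 3 samples with 4 classes each, plus their true labels.
batch_logits = [[2.0, 0.5, -1.0, 0.1],
                [0.2, 0.1, 3.0, 0.0],
                [1.0, 1.5, 0.3, -0.2]]
labels = [0, 2, 3]

preds_from_logits = [argmax(row) for row in batch_logits]
preds_from_probs = [argmax(softmax(row)) for row in batch_logits]

num_correct = sum(p == t for p, t in zip(preds_from_logits, labels))
accuracy = num_correct / len(labels)

print(preds_from_logits == preds_from_probs)  # True
print(accuracy)                               # 0.6666666666666666
```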
The loss looks like this, reported each batch: