Argmax function in Multi-Class Classification

Hi, can anyone review my multi-class classification GCN model below? I have a dataset of 4 classes with 1800 samples in total. My model outputs raw logits. I can't get the model to converge: the validation loss stays above 1 even after 1000 epochs. I suspect I may be using the argmax function incorrectly, since my targets (labels) are integers, e.g. 0, 1, 2, 3, NOT one-hot encoded.
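For reference, `CrossEntropyLoss` takes raw logits of shape `[N, C]` and, in its most common form, integer class indices of shape `[N]`, so integer labels 0 to 3 are the expected format. The sketch below uses made-up tensors (not the poster's data) to check two things: integer targets give the same loss as one-hot targets (class-probability targets are accepted from PyTorch 1.10 on), and a model that outputs identical logits for all 4 classes scores ln(4) ≈ 1.386, which is a useful chance-level baseline for the "above 1" validation loss.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_classes = 4
logits = torch.randn(8, num_classes)              # made-up raw logits, shape [N, C]
targets = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])  # integer class indices, shape [N]

loss_fn = torch.nn.CrossEntropyLoss()

# Integer targets: the usual form for single-label multi-class problems.
loss_int = loss_fn(logits, targets)

# One-hot (class-probability) targets give the same value (PyTorch >= 1.10).
one_hot = F.one_hot(targets, num_classes).float()
loss_one_hot = loss_fn(logits, one_hot)

# Chance-level baseline: identical logits for every class -> loss = ln(C).
uniform_loss = loss_fn(torch.zeros(8, num_classes), targets)

print(f"integer targets: {loss_int.item():.4f}")
print(f"one-hot targets: {loss_one_hot.item():.4f}")
print(f"uniform logits:  {uniform_loss.item():.4f} (ln 4 = {math.log(num_classes):.4f})")
```

So there is no need to one-hot encode the labels; a validation loss between 1 and 1.386 means the model is doing better than uniform guessing, but not by much.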

```python
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# `model`, `train_loader` and `val_loader` are defined elsewhere (the GCN and its
# graph DataLoaders); the model returns (raw_logits, embedding).
loss_fn = torch.nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-3)
# verbose= is deprecated in recent PyTorch; log optimizer.param_groups[0]['lr'] instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10)

def train(data_loader, is_validation=False):
    model.train() if not is_validation else model.eval()
    correct = 0
    total = 0
    loss_total = 0

    # No gradient tracking during validation (saves memory and time).
    with torch.set_grad_enabled(not is_validation):
        for batch in data_loader:
            optimizer.zero_grad()
            pred, embedding = model(batch.x.float(), batch.edge_index.long(), batch.batch.long())
            target = batch.y.long().view(-1)  # integer labels 0..3, shape [batch_size]

            loss = loss_fn(pred, target)

            if not is_validation:
                loss.backward()
                optimizer.step()

            predicted = torch.argmax(pred, dim=1)  # index of the largest logit per sample

            total += target.size(0)
            correct += (predicted == target).sum().item()
            loss_total += loss.item()

    accuracy = correct / total
    return loss_total / len(data_loader), accuracy

def validate(data_loader):
    return train(data_loader, is_validation=True)

print("Starting training...")
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
best_val_accuracy = 0.0  # must exist before the first comparison below
best_epoch = 0

for epoch in range(1001):
    train_loss, train_accuracy = train(train_loader)
    val_loss, val_accuracy = validate(val_loader)

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch
        # Save the weights of the best model so far
        torch.save(model.state_dict(), 'best_model_weights.pth')

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, Train Loss: {train_loss}, Train Accuracy: {train_accuracy}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}")

    scheduler.step(val_loss)
```
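One small detail in `train()`: `loss_total / len(data_loader)` averages the per-batch mean losses, so a short final batch counts as much as a full one. The effect is usually minor, but the reported loss is then not exactly the per-sample mean. A sketch with hypothetical batch losses and sizes (not values from the post):

```python
# Hypothetical per-batch mean losses and batch sizes: two full batches
# plus a smaller final one.
batch_losses = [1.20, 1.10, 0.40]
batch_sizes = [64, 64, 8]

# What train() reports: the unweighted mean of the batch means.
mean_of_batch_means = sum(batch_losses) / len(batch_losses)

# Per-sample mean: weight each batch mean by the number of samples in it.
per_sample_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(f"mean of batch means: {mean_of_batch_means:.4f}")
print(f"per-sample mean:     {per_sample_mean:.4f}")
```

In `train()` the per-sample version would accumulate `loss.item() * target.size(0)` and divide by `total` instead of `len(data_loader)`.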

If you have any advice for improving my model, please let me know. Thank you!
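A common first check when a model will not converge is whether it can overfit a handful of samples: if the training loss on one small batch does not approach zero, the problem is more likely in the pipeline (labels, loss wiring, model call) than in the amount of data. The sketch below uses a tiny stand-in MLP on random features, not the GCN from the post; to apply the idea, substitute the real model and a single batch from `train_loader`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in data and model (hypothetical): 32 samples, 16 features, 4 classes.
x = torch.randn(32, 16)
y = torch.randint(0, 4, (32,))
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Train repeatedly on the same small batch.
model.train()
for step in range(300):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

# Evaluate on that same batch: it should be (nearly) memorised.
model.eval()
with torch.no_grad():
    logits = model(x)
    final_loss = loss_fn(logits, y).item()
    accuracy = (logits.argmax(dim=1) == y).float().mean().item()

print(f"final loss: {final_loss:.4f}, accuracy: {accuracy:.2f}")
```

If the real GCN cannot drive the loss close to zero on one batch of its own data, that points to a bug rather than to a lack of data.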

The usage of `predicted = torch.argmax(pred, dim=1)` looks correct, assuming `pred` is the model output with shape `[batch_size, nb_classes, *]`.
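To expand on that: in the training loop `argmax` is only used to compute accuracy and is not part of the loss, so it cannot by itself stop the model from converging. Because softmax is strictly increasing, `argmax` over raw logits picks the same class as `argmax` over probabilities, and with integer labels the result can be compared to the targets directly. A small check on made-up logits of shape `[batch_size, num_classes]`:

```python
import torch

torch.manual_seed(0)

logits = torch.randn(5, 4)               # made-up raw logits: [batch_size, num_classes]
targets = torch.tensor([0, 1, 2, 3, 0])  # integer labels, not one-hot

pred_from_logits = torch.argmax(logits, dim=1)
pred_from_probs = torch.argmax(torch.softmax(logits, dim=1), dim=1)

# Same predicted class either way; shape [batch_size], directly comparable to targets.
correct = (pred_from_logits == targets).sum().item()

print(pred_from_logits.tolist(), pred_from_probs.tolist(), correct)
```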


Yes, the model runs without throwing any errors. Thank you for the clarification; maybe I just don't have enough data for the model to learn.
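Before concluding that there is not enough data, it may help to check how the 1800 samples are distributed across the four classes, since a skewed split changes what a "good" loss looks like. The sketch below uses a hypothetical label tensor: `torch.bincount` counts samples per class, and, if the classes do turn out to be imbalanced, inverse-frequency weights passed to `CrossEntropyLoss(weight=...)` are one common option (not something suggested in the thread).

```python
import torch

# Hypothetical labels standing in for the dataset's targets (values 0..3).
labels = torch.tensor([0] * 900 + [1] * 450 + [2] * 300 + [3] * 150)

counts = torch.bincount(labels, minlength=4)
print("samples per class:", counts.tolist())

# Inverse-frequency class weights N / (C * count_c); all equal 1 for a balanced dataset.
weights = counts.sum() / (len(counts) * counts.float())
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
print("class weights:", [round(w, 3) for w in weights.tolist()])
```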