Hi, can anyone review my multiclass classification GCN model below? I have a dataset of 4 classes with a total of 1800 samples, and my model outputs raw logits. I cannot get the model to converge: the validation loss stays above 1 even after 1000 epochs. I suspect I may not be using the argmax function correctly, because my targets (labels) are integers, e.g. 0, 1, 2, 3, NOT one-hot encoded.
# CrossEntropyLoss applies log-softmax internally, so it takes the model's raw
# logits directly together with integer class indices (no one-hot encoding).
# Its default reduction is 'mean' (average loss over the batch).
loss_fn = torch.nn.CrossEntropyLoss()

# Adam with a coupled L2 penalty (weight_decay) on all model parameters.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-3,
)

# Multiply the learning rate by `factor` (i.e. cut it by 10%) whenever the
# monitored metric (validation loss, mode='min') has not improved for
# `patience` consecutive scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; consider
# logging `scheduler.get_last_lr()` instead — confirm against your version.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.9,
    patience=10,
    verbose=True,
)
def train(data_loader, is_validation=False):
    """Run one pass over ``data_loader`` and return ``(mean_loss, accuracy)``.

    In training mode (the default) the module-level ``model`` is switched to
    ``train()`` and the module-level ``optimizer`` is stepped after every
    batch. With ``is_validation=True`` the model is switched to ``eval()``,
    autograd is disabled and no parameters are updated.

    The model is assumed to return ``(logits, embedding)`` where ``logits``
    are raw, un-normalised scores of shape (num_samples, num_classes) — TODO
    confirm against the model definition. ``CrossEntropyLoss`` applies
    log-softmax itself, so integer labels 0..C-1 are the correct target form.
    Because softmax is monotonic, ``argmax`` over the logits picks the same
    class as ``argmax`` over the probabilities.

    Args:
        data_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y`` attributes.
        is_validation: if True, evaluate only (no backward pass, no step).

    Returns:
        Tuple of (sample-weighted mean loss, accuracy in [0, 1]).

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if is_validation:
        model.eval()
    else:
        model.train()

    correct = 0
    total = 0
    loss_sum = 0.0

    # Track gradients only while training. During validation no autograd
    # graph is built, which saves memory and time.
    with torch.set_grad_enabled(not is_validation):
        for batch in data_loader:
            pred, _embedding = model(
                batch.x.float(), batch.edge_index.long(), batch.batch.long()
            )
            target = batch.y.long().view(-1)
            loss = loss_fn(pred, target)

            if not is_validation:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = target.size(0)
            # loss_fn is assumed to use reduction='mean' (the CrossEntropyLoss
            # default). Re-weight by batch size so a smaller final batch does
            # not skew the epoch average.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (pred.argmax(dim=1) == target).sum().item()

    if total == 0:
        raise ValueError("data_loader yielded no samples")
    return loss_sum / total, correct / total
def validate(data_loader):
    """Evaluate the model on ``data_loader`` without updating any weights.

    Delegates to ``train`` in validation mode and passes through its
    ``(loss, accuracy)`` result unchanged.
    """
    results = train(data_loader=data_loader, is_validation=True)
    return results
print("Starting training...")

# Per-epoch history, e.g. for plotting learning curves afterwards.
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# These must exist before the loop: the comparison below reads
# best_val_accuracy on the very first epoch. Starting at -inf guarantees that
# epoch 0 always writes a checkpoint.
best_val_accuracy = float("-inf")
best_epoch = None

# Reference point when reading the losses: with 4 classes, predicting a
# uniform distribution gives a cross-entropy of ln(4) ≈ 1.386. A validation
# loss that stays near that value means the model is barely better than
# chance (assuming roughly balanced classes).
for epoch in range(1001):
    train_loss, train_accuracy = train(train_loader)
    val_loss, val_accuracy = validate(val_loader)

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch
        # Saving the model weights of the best-so-far epoch.
        torch.save(model.state_dict(), 'best_model_weights.pth')

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, Train Loss: {train_loss}, Train Accuracy: {train_accuracy}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}")

    # The scheduler was built with mode='min', so it monitors the validation loss.
    scheduler.step(val_loss)
If you have any advice on improving my model, please let me know. Thank you!