Validation loss extremely large despite good results

I’m training a simple self-attention model and getting good results on the validation set (accuracy, MCC, recall and precision). I’ve repeated this over several different train/test splits. The only problem is that the validation loss is extremely large compared to the training loss. I’m attaching an example, but they all look more or less the same:

If I train for around 500 epochs, the validation loss keeps decreasing nicely (while the training loss oscillates more), but it always stays much larger. Has anyone seen something like this before?

I’m also attaching the training and validation loop:

```python
training_loss = []
val_loss = []

for epoch in range(1, num_epochs+1):
    #print(f'EPOCH: {epoch}...')

    avg_loss = 0.0
    for idx, batch in enumerate(train_loader):
        smiles, labels = batch[0].to(device), batch[1].to(device)

        # Fit

        out = model(smiles)

        loss = criterion(out, labels)
        avg_loss =+ loss.item() * smiles.size(0)

    training_loss.append(avg_loss / len(train_loader))

    # Validation
    with torch.no_grad():

        avg_loss = 0.0
        y_pred = []
        y_val = []
        for idx, batch in enumerate(val_loader):
            smiles, labels = batch[0].to(device), batch[1].to(device)

            out = model(smiles)
            y_pred.extend(list(torch.argmax(out, dim=1).detach().cpu().numpy()))

            loss = criterion(out, labels)
            avg_loss =+ loss.item() * smiles.size(0)

    val_loss.append(avg_loss / len(val_loader))
```

I believe I found the error: I did =+ rather than += when computing the average loss. Now I get:
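For context, =+ is not an operator in Python: avg_loss =+ x parses as avg_loss = (+x), so every iteration overwrites the running total with the current batch's value instead of adding to it. A minimal illustration with made-up batch losses:

```python
batch_losses = [1.0, 2.0, 3.0]  # hypothetical per-batch values

overwritten = 0.0
for batch_loss in batch_losses:
    overwritten =+ batch_loss  # same as: overwritten = +batch_loss (overwrites)

accumulated = 0.0
for batch_loss in batch_losses:
    accumulated += batch_loss  # augmented assignment (accumulates)

print(overwritten)  # 3.0, only the last batch survives
print(accumulated)  # 6.0, the actual sum
```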

However, both losses are now very large. The scores are unchanged, of course. So the question now is: does it matter that the loss is high (I’m using cross-entropy)?

It seems you are multiplying each batch loss by the batch size, so avg_loss accumulates the summed loss over all samples.
If that’s the case, you should divide by the number of samples to get the average loss, not by the length of the DataLoader, which returns the number of batches.
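A minimal sketch of per-sample averaging, assuming a standard PyTorch setup with nn.CrossEntropyLoss at its default reduction="mean". The original loop leaves the update step under "# Fit" empty, so the zero_grad/backward/step calls here are an assumption about what was elided; the function names train_one_epoch and evaluate, the tiny nn.Linear model and the random data are hypothetical stand-ins for the self-attention model and SMILES batches. Counting samples as they are seen, rather than using len(loader.dataset), also stays correct if drop_last=True.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device):
    """One training pass; returns the mean per-sample training loss."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)  # mean over this batch
        loss.backward()
        optimizer.step()
        # undo the per-batch mean, then accumulate with +=
        total_loss += loss.item() * inputs.size(0)
        n_samples += inputs.size(0)
    return total_loss / n_samples  # divide by samples, not len(loader)


def evaluate(model, loader, criterion, device):
    """Returns the mean per-sample loss over the whole loader."""
    model.eval()  # e.g. disables dropout during evaluation
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            loss = criterion(model(inputs), labels)
            total_loss += loss.item() * inputs.size(0)
            n_samples += inputs.size(0)
    return total_loss / n_samples


# Hypothetical toy data: 10 samples, 4 features, 3 classes.
torch.manual_seed(0)
x, y = torch.randn(10, 4), torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(x, y), batch_size=4)  # batches of 4, 4, 2
model = nn.Linear(4, 3)  # stand-in for the self-attention model
criterion = nn.CrossEntropyLoss()  # reduction="mean" by default
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
device = torch.device("cpu")

train_loss = train_one_epoch(model, loader, criterion, optimizer, device)
val_loss = evaluate(model, loader, criterion, device)
```

With batches of 4, 4 and 2 samples, dividing the summed loss by len(loader) = 3 instead of by the 10 samples would inflate the reported value by a factor of 10/3; with realistic batch sizes the inflation is roughly the batch size, which matches both curves looking "very large".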