Model.train() loss and model.eval() loss do not match

I am performing an MNIST classification. I have the following training/validation loop:

for epoch in range(epochs):
    t0 = time.time()
    model.train()

    running_loss = 0.
    running_acc = 0.

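    # first half: training pass over train_loader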
    for i, (image, label) in enumerate(train_loader):
        optimizer.zero_grad()
        image = image.to(device)
        label = label.long()
        label = label.to(device)
        y = model(image)
        loss = loss_function(y, label)
        loss.backward()
        optimizer.step()

        running_loss += label.shape[0] * loss.item()
        _, prediction = torch.max(y, 1)
        total = label.shape[0]
        correct = (prediction == label).sum().item()
        running_acc += correct/total * 100
        del image, label, y, loss

    print(f"epoch {epoch} | time (sec) : {time.time() - t0:.2f} | t_acc : {(running_acc / len(train_loader)):.2f} | t_loss : {(running_loss / len(train_loader.dataset)):.2f}", end=" | ")

    total = 0
    correct = 0

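    # second half: evaluation pass (normally over validation_loader; here over train_loader as a sanity check)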
    with torch.no_grad():
        model.eval()
        running_loss = 0.
        for i, (image, label) in enumerate(train_loader):
            image = image.to(device)
            label = label.long()
            label = label.to(device)
            y = model(image)
            loss = loss_function(y, label)
            running_loss += label.shape[0] * loss.item()
            _, prediction = torch.max(y, 1)
            total += label.shape[0]
            correct += (prediction == label).sum().item()

    print(f"v_acc : {(correct/total * 100):.2f} | v_loss : {(running_loss / len(train_loader.dataset)):2f}")

Note that the second half of the loop (the torch.no_grad() block) would usually iterate over a validation_loader. However, suspecting that my model was not training properly, I replaced validation_loader with train_loader. I should therefore see more or less equal values for t_loss and v_loss, and for t_acc and v_acc. I understand that there will be slight discrepancies: t_acc is a mean of per-batch accuracies whereas v_acc is computed over the entire training set (see the toy example after the log below), and t_loss is accumulated while the weights are still being updated whereas v_loss is computed with the final weights of the epoch. Nevertheless, looking at the output below, the discrepancies seem too large:

epoch 0 | time (sec) : 29.59 | t_acc : 21.94 | t_loss : 2.10 | v_acc : 20.28 | v_loss : 2.236833
epoch 1 | time (sec) : 29.61 | t_acc : 23.26 | t_loss : 2.06 | v_acc : 21.78 | v_loss : 2.374591
epoch 2 | time (sec) : 29.62 | t_acc : 28.54 | t_loss : 1.88 | v_acc : 28.54 | v_loss : 1.955524
epoch 3 | time (sec) : 30.20 | t_acc : 40.07 | t_loss : 1.54 | v_acc : 27.01 | v_loss : 2.046747
epoch 4 | time (sec) : 29.40 | t_acc : 44.82 | t_loss : 1.46 | v_acc : 30.15 | v_loss : 1.866847
epoch 5 | time (sec) : 29.38 | t_acc : 55.17 | t_loss : 1.23 | v_acc : 40.11 | v_loss : 1.924560
epoch 6 | time (sec) : 29.70 | t_acc : 77.41 | t_loss : 0.70 | v_acc : 79.59 | v_loss : 0.626660
epoch 7 | time (sec) : 29.65 | t_acc : 84.70 | t_loss : 0.48 | v_acc : 69.35 | v_loss : 0.946439
epoch 8 | time (sec) : 29.93 | t_acc : 87.11 | t_loss : 0.41 | v_acc : 86.33 | v_loss : 0.430216
epoch 9 | time (sec) : 29.97 | t_acc : 88.47 | t_loss : 0.37 | v_acc : 65.38 | v_loss : 1.158408
epoch 10 | time (sec) : 29.90 | t_acc : 89.18 | t_loss : 0.34 | v_acc : 71.23 | v_loss : 0.953730
epoch 11 | time (sec) : 29.63 | t_acc : 90.37 | t_loss : 0.30 | v_acc : 75.13 | v_loss : 0.789689
epoch 12 | time (sec) : 29.62 | t_acc : 90.97 | t_loss : 0.29 | v_acc : 62.28 | v_loss : 1.321358
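To make the averaging point concrete, here is a toy example (made-up numbers, not from my run) of how a mean of per-batch accuracies can differ from the accuracy over the whole set when the last batch is smaller:

per_batch_sizes = [64, 64, 32]   # hypothetical batch sizes, smaller final batch
per_batch_correct = [32, 48, 8]  # hypothetical correct predictions per batch

# t_acc convention: average the per-batch accuracies
mean_over_batches = sum(c / n for c, n in zip(per_batch_correct, per_batch_sizes)) / len(per_batch_sizes)

# v_acc convention: total correct over total samples
over_whole_set = sum(per_batch_correct) / sum(per_batch_sizes)

print(mean_over_batches, over_whole_set)  # 0.5 vs 0.55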

I have a single nn.BatchNorm1d in my fully connected block, and since this is the only layer that is affected by model.eval(), I was thinking that perhaps this was the culprit. Still, I find it hard to believe that this alone is causing the above-mentioned discrepancies. Why am I experiencing this behavior?
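For reference, here is a minimal standalone sketch (not my actual model) of how a freshly initialized nn.BatchNorm1d can behave very differently in train() and eval() mode while its running statistics are still far from the statistics of the data:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 5 + 3  # data whose stats differ from the initial running stats

bn.train()
out_train = bn(x)  # normalized with this batch's mean/var -> per-feature mean ~ 0
bn.eval()
out_eval = bn(x)   # normalized with the barely updated running mean/var

print(out_train.mean().item())  # ~ 0
print(out_eval.mean().item())   # clearly non-zero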

I don’t fully understand your code, as it seems your train_loader might have only a single batch?
Both print statements are inside the DataLoader loop, so I would expect to see nb_batches outputs for the training loss and nb_batches * nb_batches outputs for the validation accuracy in each epoch. However, only a single output is given per epoch, which points towards batch_size=len(dataset).
You are also not calling model.train() inside the training loop, but at the beginning of the epoch loop; if your train_loader has multiple batches, only the first one will be used in train() mode.
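In code, the suggestion looks something like this (a minimal sketch, reusing the variable names from your loop):

for epoch in range(epochs):
    for i, (image, label) in enumerate(train_loader):
        model.train()  # make sure every batch runs in train() mode
        # ... forward pass, loss, backward, optimizer.step() ...

    # evaluation block with model.eval() follows here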

@ptrblck I apologize, the indentation of the first print statement was off. I’ve since edited the post. I will make the change to model.train() and update you.

Is the same true for model.eval()? Should it be moved into the validation loop?

Thanks for the update. Your edited code now has a syntax error, as the first print statement uses invalid indentation:


epochs = 10
for epoch in range(epochs):
    # dataloader loop
    for i in range(10):
        # model training
        a = 1
        
    print(f"epoch {epoch} | time (sec) : {time.time() - t0:.2f} | t_acc : {(running_acc / len(train_loader)):.2f} | t_loss : {(running_loss / len(train_loader.dataset)):.2f}", end=" | ")

        total = 0      # still indented one level too deep -> unexpected indent
        correct = 0

Output:

    total = 0
    ^
IndentationError: unexpected indent
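The block after the print statement presumably needs to be dedented back to the epoch-loop level, e.g.:

epochs = 10
for epoch in range(epochs):
    # dataloader loop
    for i in range(10):
        # model training
        a = 1

    print(f"epoch {epoch} | ...", end=" | ")

    # dedented to the same level as the print statement
    total = 0
    correct = 0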

Sorry! Let’s try this one more time. I’ve re-edited.

@ptrblck Here’s my output after moving model.train() into the training loop:

epoch 0 | time (sec) : 31.82 | t_acc : 22.66 | t_loss : 2.08 | v_acc : 14.39 | v_loss : 3.87
epoch 1 | time (sec) : 30.75 | t_acc : 23.25 | t_loss : 2.05 | v_acc : 16.36 | v_loss : 3.30
epoch 2 | time (sec) : 30.79 | t_acc : 23.64 | t_loss : 2.04 | v_acc : 11.58 | v_loss : 3.09
epoch 3 | time (sec) : 30.77 | t_acc : 41.57 | t_loss : 1.52 | v_acc : 51.23 | v_loss : 1.26
epoch 4 | time (sec) : 30.86 | t_acc : 72.73 | t_loss : 0.83 | v_acc : 59.36 | v_loss : 1.56
epoch 5 | time (sec) : 30.85 | t_acc : 86.40 | t_loss : 0.44 | v_acc : 80.78 | v_loss : 0.59
epoch 6 | time (sec) : 30.81 | t_acc : 90.07 | t_loss : 0.32 | v_acc : 38.93 | v_loss : 2.29
epoch 7 | time (sec) : 30.80 | t_acc : 91.60 | t_loss : 0.27 | v_acc : 69.93 | v_loss : 0.93
epoch 8 | time (sec) : 30.81 | t_acc : 92.37 | t_loss : 0.24 | v_acc : 78.57 | v_loss : 0.64
epoch 9 | time (sec) : 30.88 | t_acc : 93.13 | t_loss : 0.22 | v_acc : 69.25 | v_loss : 1.04

I still seem to have the same issue.

Could you please once again post what the code looks like after all the edits?