Calculating epoch loss incorrectly?


My epoch loss curve for my multi-class UNET segmentation model looks really off. I get this weird periodic behavior, and the loss doesn’t seem to stabilize even after 500 epochs.


Below is my training function:

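import torch
from tqdm import tqdm

def train_fn(loader, model, optimizer, loss_fn, scaler, loss_values):
    loop = tqdm(loader)
    running_loss = 0.0
    for batch_idx, (data, targets, _) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.long().unsqueeze(1).to(device=DEVICE)
        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)  # should be (N, C, H, W) where C = num of classes
            loss = loss_fn(predictions, targets.squeeze(1))
        # backward (mixed precision: scale the loss, step the optimizer, update the scaler)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss =+ loss.item() * data.size(0)
        # update tqdm loop
        loop.set_postfix(loss=loss.item())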

which gets called for each epoch in the following snippet from main():

    scaler = torch.cuda.amp.GradScaler()
    loss_values = []
    start = time.time()
    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, loss_values)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
        }

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)
    end = time.time()

Any idea why this might be the case? Thank you!

In this line of code you are scaling the current batch loss by the batch size and accumulating it:

running_loss =+ loss.item() * data.size(0)

while you are normalizing it by the batch size instead of by the dataset length when computing the epoch loss.

My guess is that the last batch might be smaller than the others, which would thus increase the loss computed in the last step.
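A small pure-Python illustration of this guess, under made-up assumptions: 10 samples loaded with a batch size of 4 (so the batches hold 4, 4 and 2 samples), hypothetical per-batch mean losses standing in for `loss.item()`, and an epoch loss divided by the size of the last batch, which is what `data.size(0)` refers to once the loop has finished. It also shows why the accumulation has to use `+=`: the line quoted above is written `=+`, which Python parses as `running_loss = +(...)`, so the total would be overwritten on every iteration.

```python
# Made-up per-batch mean losses (stand-ins for loss.item()) and batch sizes:
# 10 samples with batch_size=4, so the last batch holds only 2 samples.
batch_losses = [0.75, 0.5, 0.25]
batch_sizes = [4, 4, 2]
dataset_len = sum(batch_sizes)  # 10, i.e. len(loader.dataset)

running_loss = 0.0
overwritten = 0.0
for loss, n in zip(batch_losses, batch_sizes):
    running_loss += loss * n  # accumulates the sum of per-sample losses
    overwritten =+ loss * n   # same as overwritten = +(loss * n): keeps only this batch

# Dividing by the last batch size (2) inflates the epoch loss ...
by_last_batch = running_loss / batch_sizes[-1]
# ... while dividing by the number of samples gives the mean per-sample loss.
by_dataset = running_loss / dataset_len

print(running_loss, overwritten)  # 5.5 0.5
print(by_last_batch, by_dataset)  # 2.75 0.55
```

With a batch size of 1 every batch has the same size, so this particular last-batch effect disappears.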

Thank you for your response! That makes sense to me overall but I am confused about some details.

Also, I forgot to mention that my batch size is 1…

Did you mean that I am normalizing it with the Dataset length instead of the batch size? So should I have the following instead?


No, data.size(0) would return the same value as len(data):

import torch
data = torch.randn(6, 4)
print(data.size(0))
> 6
print(len(data))
> 6

Since you are scaling the loss via:

running_loss += current_batch_loss * nb_samples_in_batch

you would have to divide by the number of samples in the entire dataset (len(loader.dataset)).
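As a quick sanity check of this rule, the sketch below uses a hypothetical `epoch_loss` helper and made-up per-sample losses (pure Python, no PyTorch). It splits the data into batches the way a `DataLoader` with the default `drop_last=False` does (so the last batch may be shorter), accumulates `batch_mean * batch_size`, and divides by the dataset length. It assumes each batch's `loss.item()` is the mean of that batch's per-sample losses. Under that assumption the result equals the plain per-sample mean for any batch size, including 1.

```python
def epoch_loss(sample_losses, batch_size):
    """Hypothetical helper: mimics running_loss += loss.item() * data.size(0)
    over all batches, followed by running_loss / len(loader.dataset)."""
    running_loss = 0.0
    for start in range(0, len(sample_losses), batch_size):
        batch = sample_losses[start:start + batch_size]  # last batch may be shorter
        batch_mean = sum(batch) / len(batch)  # stand-in for a mean-reduced loss.item()
        running_loss += batch_mean * len(batch)
    return running_loss / len(sample_losses)


# Made-up per-sample losses for a dataset of 10 samples.
losses = [0.9, 0.4, 0.7, 0.2, 0.5, 0.8, 0.3, 0.6, 0.1, 1.0]
mean_loss = sum(losses) / len(losses)

for bs in (1, 3, 4, 10):
    print(bs, round(epoch_loss(losses, bs), 6))  # 0.55 for every batch size
```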

Thank you very much! That makes sense. I corrected my running loss and now I am training my model for 300 epochs. I will let you know if it was indeed the initial running loss calculation that was causing the graph to be unstable.

Unfortunately, after correcting that, I still get the unstable plot, even with a batch size of 1.