Why is my training loss noisy while my validation loss is okay?

Hi, my training loss looks haphazard, but my validation loss is good. Does anyone have an idea why this might be? Both losses are shown below.

Could you describe how these losses are calculated?
E.g. if you are printing the mean training loss for each batch but the mean validation loss for each epoch, the different noise levels would be expected.
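For illustration, this is the pattern I mean (a minimal, self-contained sketch with toy placeholders, not your actual model or data):

    import torch
    import torch.nn as nn

    # toy placeholders, just to illustrate the two logging patterns
    model = nn.Linear(10, 1)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    train_loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(20)]
    val_loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(5)]

    for epoch in range(3):
        # one loss value printed per batch -> noisy curve
        for x, y in train_loader:
            loss = criterion(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss.item())

        # one averaged value printed per epoch -> smoother curve
        val_losses = []
        with torch.no_grad():
            for x, y in val_loader:
                val_losses.append(criterion(model(x), y).item())
        print(sum(val_losses) / len(val_losses))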

This is how I compute the losses. Note that “compute_first_loss” and “compute_second_loss” are methods:

    for epoch in tqdm(range(epochs)):
        for batch, data in enumerate(training_generator):
            X, Y = data[0].squeeze(-1).to(device).float(), data[1].squeeze(-1).to(device).float()

            output = model(X, Y)
            l1 = compute_first_loss(output[0][:-1])
            l2 = compute_second_loss(output[0][-1])
            loss = l1 + l2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss.append(loss.item())

            print(loss.item())

        # validation using current weights
        for val_batch, data in enumerate(validation_generator):
            X, Y = data[0].squeeze(-1).to(device).float(), data[1].squeeze(-1).to(device).float()
            with torch.no_grad():
                output = model(X, Y)
                l1_val = compute_first_loss(output[0][:-1])
                l2_val = compute_second_loss(output[0][-1])
                loss_val = l1_val + l2_val

                total_val_loss.append(loss_val.item())
                print(loss_val.item())

Do these plots show the calculated loss and loss_val from the code, or how are you plotting these curves?

Yes, loss and loss_val are what’s plotted. I’m using wandb to plot the curves, so the code is:

    wandb.log({"training_loss": loss.item(), "validation_loss": loss_val.item()})

This is logged once per epoch, so for every epoch I plot the last training and validation loss that were computed.
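For reference, if the goal is curves with comparable noise levels, one option would be to log the per-epoch mean of the collected losses instead of the last batch value. A sketch based on the total_loss and total_val_loss lists above (it assumes both lists are re-initialized to [] at the start of each epoch):

    # at the end of each epoch, assuming total_loss and total_val_loss
    # were reset to [] when the epoch started
    wandb.log({
        "training_loss": sum(total_loss) / len(total_loss),
        "validation_loss": sum(total_val_loss) / len(total_val_loss),
    })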