Abnormal Train/Validation plots

Hello, I am trying to fine-tune a GPT-2 model from Hugging Face on my own data. I have a collection of 11984 recipes (many of them are similar to each other), and below is my training loop, in which I collect the loss values so I can plot them later. My batch size is 2 and I'm training for 3 epochs.

I split the data 80-20 into train and test, and then split the training portion 90-10 into training and validation. This leaves me with 8628 training samples, 959 validation samples, and 2397 test samples.
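For reference, the split sizes above can be reproduced with a quick calculation. This is only a sketch: it assumes each held-out portion is rounded up to a whole sample, and `split_sizes` is a hypothetical helper written for illustration, not part of my actual code.

```python
import math


def split_sizes(n_total, test_frac=0.2, val_frac=0.1):
    """Return (n_train, n_val, n_test) for a two-stage split.

    Assumption: each held-out portion is rounded up to a whole sample.
    """
    n_test = math.ceil(n_total * test_frac)    # 20% of everything
    n_remaining = n_total - n_test             # what is left for train + val
    n_val = math.ceil(n_remaining * val_frac)  # 10% of the remainder
    n_train = n_remaining - n_val
    return n_train, n_val, n_test


print(split_sizes(11984))  # (8628, 959, 2397)
```

With a batch size of 2, this gives roughly 4314 training batches and 480 validation batches per epoch.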

This is my training loop:

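import math

import torch

# Assumed to be defined earlier: model, device, epochs, optimizer, scheduler,
# train_dataloader, validation_dataloader, sample_every, save_every, save_file.
training_stats = []
print("Currently using device type: ", device)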

model = model.to(device)

## My loss ##
all_losses = []
all_val_losses = []

for epoch_i in range(0, epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')
    
    ## My loss##
    epoch_losses = []
    epoch_val_losses = []
    
    losses = []
    total_train_loss = 0
    model.train()

    for step, batch in enumerate(train_dataloader):

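        # Labels are the input IDs themselves; the model shifts them
        # internally for next-token prediction.
        b_input_ids = batch[0].to(device)
        b_labels = batch[0].to(device)
        b_masks = batch[1].to(device)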

        model.zero_grad()        

        outputs = model(b_input_ids, labels=b_labels, attention_mask=b_masks, token_type_ids=None)

        loss = outputs[0]  

        batch_loss = loss.item()
        total_train_loss += batch_loss
        losses.append(batch_loss)
        
        ##My loss##
        epoch_losses.append(batch_loss)
        
        # Print progress every `sample_every` batches (skipping the first batch).
        if step % sample_every == 0 and step != 0:
            print('Batch {:>5,}  of  {:>5,}. Loss: {:.4f}.'.format(step, len(train_dataloader), batch_loss))

        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % save_every == 0:
            model.save_pretrained(save_file)

    # Calculate the average loss over all of the batches.
    avg_train_loss = total_train_loss / len(train_dataloader)       
    
    # Calculate perplexity as exp(mean batch loss).
    train_perplexity = math.exp(torch.tensor(losses).mean().item())
    
    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Perplexity: {0:.2f}".format(train_perplexity))   

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    model.eval()

    losses = []
    total_eval_loss = 0

    # Evaluate on the full validation set.
    for batch in validation_dataloader:
        b_input_ids = batch[0].to(device)
        b_labels = batch[0].to(device)
        b_masks = batch[1].to(device)
        
        with torch.no_grad():        
            outputs  = model(b_input_ids, attention_mask=b_masks,labels=b_labels)
            loss = outputs[0]  
            
        batch_loss = loss.item()
        losses.append(batch_loss)
        total_eval_loss += batch_loss
        
        ## My loss ##
        epoch_val_losses.append(batch_loss)

    avg_val_loss = total_eval_loss / len(validation_dataloader)
    
    # Calculate perplexity as exp(mean batch loss).
    val_perplexity = math.exp(torch.tensor(losses).mean().item())

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation perplexity: {0:.2f}".format(val_perplexity))        

    # Record all statistics from this epoch.
    training_stats.append({
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Perplexity': train_perplexity,
            'Valid. Perplexity': val_perplexity})
    
    ## My loss ##
    all_losses.append(epoch_losses)
    all_val_losses.append(epoch_val_losses)

print("")
print("Training complete!")
model.save_pretrained(save_file)

However, when I later plot the losses like this (essentially I'm plotting the per-batch losses of the 3rd epoch only):

import matplotlib.pyplot as plt

plt.figure(figsize=(10,5))
plt.title("Training and Validation Loss")
plt.plot(all_losses[2],label="train")
plt.plot(all_val_losses[2],label="val")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

I get the following “abnormal” plots.

From my understanding, most of the jumping around is caused by the small batch size I'm using, and I expected that. However, I assumed there would be a clearer downward curve showing the training progress. My final test loss was 0.02.
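To check whether a trend is hiding under the batch-to-batch noise, one option is to join the per-batch losses of all epochs into one series and smooth it with a moving average before plotting. Below is a minimal sketch, assuming `all_losses` is the list of per-epoch lists built in the loop above. `flatten_epochs`, `moving_average`, the toy data, and the window size are hypothetical names and values chosen for illustration.

```python
def flatten_epochs(per_epoch_losses):
    """Join per-epoch lists of batch losses into one chronological list."""
    return [loss for epoch in per_epoch_losses for loss in epoch]


def moving_average(values, window):
    """Mean over a sliding window; returns len(values) - window + 1 points."""
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    running = sum(values[:window])
    averages = [running / window]
    for i in range(window, len(values)):
        # Slide the window: add the newest value, drop the oldest one.
        running += values[i] - values[i - window]
        averages.append(running / window)
    return averages


# Toy data standing in for all_losses (3 epochs, 4 batches each).
toy_losses = [[4.0, 2.0, 3.0, 1.0], [2.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
flat = flatten_epochs(toy_losses)
smoothed = moving_average(flat, window=4)
print(len(flat), smoothed[0], smoothed[-1])
```

With the real data, something like `plt.plot(moving_average(flatten_epochs(all_losses), 100))` would show all three epochs on one axis instead of only the last one. Note that validation has far fewer batches per epoch than training (about 480 versus 4314 here), so the two curves do not share the same "iterations" scale when plotted against their list index.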

Am I doing something wrong, or is it just the “bad quality” of the data that is causing this? Is a result like this OK to move on with, as long as the model works well enough for what I'm trying to do?

Thanks in advance and sorry for the lengthy post!