Hello, I am trying to fine-tune a GPT-2 model from Hugging Face on my own data. I have a collection of 11,984 recipes (many of them are similar to each other), and below is my training loop, in which I also gather the loss values so I can plot them later. My batch size is 2 and I'm training for 3 epochs.
I have also done a train/validation/test split: 80-20 for train-test, then 90-10 for training-validation, which leaves me with 8,628 training samples, 959 validation samples, and 2,397 test samples.
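For reference, the split itself looks roughly like this (a sketch; full_dataset stands for my tokenized recipe dataset and the seed is arbitrary):

import torch
from torch.utils.data import random_split

# 80-20 train/test, then 90-10 train/validation, matching the counts above.
train_val_size = int(0.8 * len(full_dataset))    # 9587
test_size = len(full_dataset) - train_val_size   # 2397
train_size = int(0.9 * train_val_size)           # 8628
val_size = train_val_size - train_size           # 959

train_val_dataset, test_dataset = random_split(
    full_dataset, [train_val_size, test_size],
    generator=torch.Generator().manual_seed(42))
train_dataset, val_dataset = random_split(
    train_val_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42))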
This is my training loop:
import math
import torch

training_stats = []
print("Currently using device type:", device)
model = model.to(device)

## My loss ##
all_losses = []      # per-batch training losses, one list per epoch
all_val_losses = []  # per-batch validation losses, one list per epoch

for epoch_i in range(epochs):
    # ========================================
    #               Training
    # ========================================
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    ## My loss ##
    epoch_losses = []
    epoch_val_losses = []

    losses = []
    total_train_loss = 0
    model.train()

    for step, batch in enumerate(train_dataloader):
        b_input_ids = batch[0].to(device)
        # For causal LM the labels are the inputs themselves;
        # the model shifts them internally.
        b_labels = b_input_ids
        b_masks = batch[1].to(device)

        model.zero_grad()
        outputs = model(b_input_ids,
                        labels=b_labels,
                        attention_mask=b_masks,
                        token_type_ids=None)
        loss = outputs[0]  # the LM loss is the first element of the output

        batch_loss = loss.item()
        total_train_loss += batch_loss
        losses.append(batch_loss)
        ## My loss ##
        epoch_losses.append(batch_loss)

        # Print progress every `sample_every` batches.
        if step % sample_every == 0 and step != 0:
            print('Batch {:>5,} of {:>5,}. Loss: {:>5.2f}.'.format(
                step, len(train_dataloader), batch_loss))

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Checkpoint every `save_every` batches (skip step 0 so the
        # checkpoint isn't overwritten with an untouched model).
        if step % save_every == 0 and step != 0:
            model.save_pretrained(save_file)

    # Calculate the average loss over all of the batches.
    avg_train_loss = total_train_loss / len(train_dataloader)

    # Perplexity = exp(mean cross-entropy loss).
    losses = torch.tensor(losses)
    train_perplexity = math.exp(torch.mean(losses).item())

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Perplexity: {0:.2f}".format(train_perplexity))

    # ========================================
    #               Validation
    # ========================================
    print("")
    print("Running Validation...")
    model.eval()

    losses = []
    total_eval_loss = 0

    # Evaluate data for one epoch.
    for batch in validation_dataloader:
        b_input_ids = batch[0].to(device)
        b_labels = b_input_ids
        b_masks = batch[1].to(device)

        with torch.no_grad():
            outputs = model(b_input_ids,
                            attention_mask=b_masks,
                            labels=b_labels)
            loss = outputs[0]

        batch_loss = loss.item()
        losses.append(batch_loss)
        total_eval_loss += batch_loss
        ## My loss ##
        epoch_val_losses.append(batch_loss)

    avg_val_loss = total_eval_loss / len(validation_dataloader)

    # Validation perplexity.
    losses = torch.tensor(losses)
    val_perplexity = math.exp(torch.mean(losses).item())

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation perplexity: {0:.2f}".format(val_perplexity))

    # Record all statistics from this epoch.
    training_stats.append({
        'epoch': epoch_i + 1,
        'Training Loss': avg_train_loss,
        'Valid. Loss': avg_val_loss,
        'Training Perplexity': train_perplexity,
        'Valid. Perplexity': val_perplexity,
    })

    ## My loss ##
    all_losses.append(epoch_losses)
    all_val_losses.append(epoch_val_losses)

print("")
print("Training complete!")
model.save_pretrained(save_file)
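For completeness, the loop above assumes the model, dataloaders, optimizer, and scheduler are created beforehand, roughly like this (the learning rate, warmup steps, sample_every/save_every values, and save path here are placeholders rather than my exact settings):

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import GPT2LMHeadModel, get_linear_schedule_with_warmup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel.from_pretrained("gpt2")

epochs = 3
batch_size = 2
sample_every = 100   # placeholder
save_every = 1000    # placeholder
save_file = "model_save/"  # placeholder

train_dataloader = DataLoader(train_dataset,
                              sampler=RandomSampler(train_dataset),
                              batch_size=batch_size)
validation_dataloader = DataLoader(val_dataset,
                                   sampler=SequentialSampler(val_dataset),
                                   batch_size=batch_size)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=100,
                                            num_training_steps=total_steps)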
However, when I later plot the losses like this (essentially plotting the per-batch losses of the 3rd epoch):
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.title("Training and Validation Loss")
plt.plot(all_losses[2], label="train")
plt.plot(all_val_losses[2], label="val")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
I get the following “abnormal” plots.
From my understanding, most of the jumping around is caused by the small batch size I'm using, and I expected that, but I assumed there would be a clearer downward curve showing the training progress. My final test loss was 0.02.
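To check whether a downward trend is hiding under the noise, one option I'm considering is smoothing the per-batch losses with a running mean before plotting (a sketch; the window size of 50 is arbitrary):

import numpy as np

def running_mean(values, window=50):
    # Simple moving average over the noisy per-batch losses.
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

plt.plot(running_mean(all_losses[2]), label="train (smoothed)")
plt.legend()
plt.show()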
Am I doing something wrong, or is it just the "bad quality" of the data that is causing this? Is something like this OK to move on with, as long as the model works well enough for what I'm trying to do?
Thanks in advance, and sorry for the lengthy post!