Wondering why my model is plateauing

I’m training a deep neural network on temporal graph data. Right now I’m trying to get a feel for how large / complex a model I should aim for, so as a first step I’m attempting to overfit my smallest dataset. Here is my code:

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import sys
import random
import os
import datetime

# Command to run
# python3 my_dnn.py [dataset_basename] [num_layers] [batch_size] [neurons_per_hidden_layer] [data_directory] [num_epochs] [learning_rate] [rng_seed]
# python3 my_dnn.py enron 25 20 256 ./formatted_data 10 0.01 48444316

# Setting the CUBLAS environment variable for deterministic operations
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Name of the dataset to train on
DATASET_BASENAME = sys.argv[1]

# Number of input features (two edge endpoints, the insertion time, and two numeric features)
INPUT_SIZE = 5

# Number of stacked RNN layers
NUM_LAYERS = int(sys.argv[2])
BATCH_SIZE = int(sys.argv[3])
OUTPUT_SIZE = 1
HIDDEN_DIM = int(sys.argv[4])
DATA_DIRECTORY = str(sys.argv[5])

# Number of previous timestamps taken into account
SEQ_LENGTH = 2
NUM_EPOCHS = int(sys.argv[6])
LEARNING_RATE = float(sys.argv[7])
# Use a plain RNN; set to True to use an LSTM instead
LSTM = False

# Try to use gpu if it's there
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
print("GPU is available" if is_cuda else "GPU not available, CPU used")


# Require all algorithms to be deterministic; the random seed (argv[8]) is fixed for reproducibility
torch.use_deterministic_algorithms(True)
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
set_seed(int(sys.argv[8]))


class TemporalGraphDataset(Dataset):
    def __init__(self, feature_file, label_file, seq_length):
        self.features = pd.read_csv(feature_file)
        self.labels = pd.read_csv(label_file)
        self.seq_length = seq_length

    def __len__(self):
        return len(self.features) - self.seq_length + 1

    def __getitem__(self, idx):
        start_idx = idx
        end_idx = idx + self.seq_length
        # A window of seq_length consecutive feature rows...
        data = self.features.iloc[start_idx:end_idx].to_numpy()
        # ...paired with the label of the window's last row
        target = self.labels.iloc[end_idx - 1].to_numpy()
        return torch.tensor(data, dtype=torch.float32), torch.tensor(target, dtype=torch.float32)

features_path = f'{DATA_DIRECTORY}/formatted_{DATASET_BASENAME}/{DATASET_BASENAME}_features.csv'
labels_path = f'{DATA_DIRECTORY}/formatted_{DATASET_BASENAME}/{DATASET_BASENAME}_labels.csv'
dataset = TemporalGraphDataset(features_path, labels_path, SEQ_LENGTH)

# Randomly hold out 100 windows for validation and train on the rest
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 100, 100])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

class DNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_dim, n_layers):
        super(DNN, self).__init__()
        # Plain RNN by default; LSTM if the flag is set
        if LSTM:
            self.rnn = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True)
        else:
            self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_size)

    def forward(self, x):
        # The initial hidden state defaults to zeros, so it doesn't need to be
        # built by hand (and nn.LSTM would require an (h0, c0) tuple, not a tensor)
        out, _ = self.rnn(x)
        return self.fc(out[:, -1, :])  # Take the output of the last timestep

model = DNN(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_DIM, NUM_LAYERS).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# Validation function
def validate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
    return total_loss / len(dataloader)

# Training function that runs validation after each epoch
def train_model(model, train_loader, val_loader, optimizer, criterion, epochs):
    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # Clip only pathologically large gradients (norm > 1000)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1000.0)
            optimizer.step()
            total_train_loss += loss.item()
            
        # Store the average training loss per batch for this epoch
        train_losses.append(total_train_loss / len(train_loader))

        val_loss = validate(model, val_loader, criterion)
        val_losses.append(val_loss)

        print(f'Epoch {epoch+1}, Training Loss: {train_losses[-1]}, Validation Loss: {val_losses[-1]}')

    # Plotting
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    today = datetime.datetime.now().strftime("%Y-%m-%d")
    plot_filename = f"{DATASET_BASENAME}_{today}.png"
    plt.savefig(plot_filename)
    print(f"Plot saved as {plot_filename}")
    plt.show()

train_model(model, train_loader, val_loader, optimizer, criterion, NUM_EPOCHS)

Here is the output from the terminal:

(env) python3 my_dnn.py enron 25 20 256 ./formatted_data 10 0.01 48444316
GPU is available
Epoch 1, Training Loss: 255494918.86796424, Validation Loss: 178531414.4
Epoch 2, Training Loss: 130291595.62125197, Validation Loss: 80107624.8
Epoch 3, Training Loss: 52912246.85954761, Validation Loss: 29422431.2
Epoch 4, Training Loss: 22821705.16570226, Validation Loss: 21934850.0
Epoch 5, Training Loss: 20782540.18095739, Validation Loss: 21980772.4
Epoch 6, Training Loss: 20782165.60441873, Validation Loss: 22031612.8
Epoch 7, Training Loss: 20781856.248290375, Validation Loss: 21950381.0
Epoch 8, Training Loss: 20781392.246712256, Validation Loss: 21956635.4
Epoch 9, Training Loss: 20776616.859021567, Validation Loss: 21970061.2
Epoch 10, Training Loss: 20781443.142556548, Validation Loss: 21955459.8

The training loss seems to plateau at ~2 × 10^7 mean squared error. Why could this be? Are there any obvious bugs in my code, and what are some ideas for getting out of this plateau?
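
One idea I've been meaning to try: checking whether the plateau is just the model collapsing to predicting the mean of the raw targets, since the variance of the targets is exactly the MSE a constant mean predictor would achieve. Here's a minimal sketch of that check (it assumes the label CSV has a single numeric column; the paths match my run above), plus the standardization I'd apply if that turns out to be the case:

import numpy as np
import pandas as pd

# Same label file the training script reads
labels = pd.read_csv('./formatted_data/formatted_enron/enron_labels.csv')
y = labels.iloc[:, 0].to_numpy(dtype=np.float64)

# The variance of the targets equals the MSE of a constant predictor that
# always outputs the mean; if this is ~2e7, the plateau is the mean-predictor baseline
print('target mean:', y.mean())
print('target variance (mean-predictor MSE):', y.var())

# Standardize targets to zero mean / unit variance before training;
# predictions would be mapped back with y_hat * y.std() + y.mean()
y_scaled = (y - y.mean()) / y.std()

But I'm not sure whether target scale alone explains it, or whether something like the learning rate (0.01) or the 25 stacked RNN layers is also part of the problem.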