LLM is not learning

Hello everyone,

Over the past few weeks I have programmed my own LLM in PyTorch, and after a lot of debugging the code runs without any errors. However, after about 500 batches the loss drops to around 0.1, yet the actual text generation produces nothing meaningful. My data is stored in multiple .pth files, already tokenized and batched.

Sorry for the long text and the large amount of code, but I really don’t know what to do anymore.

import torch
import torch.nn as nn
import torch.optim as optim

from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
from transformers import GPT2TokenizerFast

from tqdm import tqdm
import os
import time
import math
import random
import re
import matplotlib.pyplot as plt
import warnings

from dataclasses import dataclass

@dataclass
class Config:
    # Hyperparameters
    learning_rate = 1e-4
    criterion = nn.CrossEntropyLoss()
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    # Transformer parameters
    src_vocab_size = len(tokenizer)
    tgt_vocab_size = len(tokenizer)
    seq_lenght = 128
    emb_dim = 512
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048
    dropout = 0.1

    # Model parameters
    epochs = 5
    batch_size = 32
    data_dir = "C:\\Users\\Bennet\\Desktop\\Projekte\\token_batches"
    checkpoint_path = "C:\\Users\\Bennet\\Desktop\\Projekte\\Encoder_Model1\\Checkpoints"
    buffer_size = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Settings for the __name__ == "__main__" loop
    num_workers = 0         # os.cpu_count()
    pin_memory = True
    How_many_epochs = 5

    # Checkpoint variables
    Neuste_Datei = 0
    token_file = 0

Config = Config()

Config.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

def mask(size):
    print("mask aufgerufen")
    mask = torch.triu(torch.ones(size, size))
    mask = mask.masked_fill(mask == 1, float("-inf"))
    return mask

def natural_sort_key(s):
    print("natural_sort_key aufgerufen")
    return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]

def positional_encoding():
    print("positinal Encoding Aufgerufen")
    pe = torch.zeros(Config.seq_lenght, Config.emb_dim)
    for pos in range(Config.seq_lenght):
        for i in range(0, Config.emb_dim, 2):
            pe[pos, i] = math.sin(pos / (10000 ** (2 * i / Config.emb_dim)))
            pe[pos, i + 1] = math.cos(pos / (10000 ** (2 * i / Config.emb_dim)))
    return pe

class TransformerModel(nn.Module):
    def __init__(self):
        print("Transformer class aufgerufen")
        super(TransformerModel, self).__init__()

        self.scr_embedding = nn.Embedding(Config.src_vocab_size, Config.emb_dim)
        self.tgt_embedding = nn.Embedding(Config.tgt_vocab_size, Config.emb_dim)
        self.positional_Encoding = positional_encoding

        self.Transformer = nn.Transformer(
            d_model=Config.emb_dim,
            nhead=Config.nhead,
            num_encoder_layers=Config.num_encoder_layers,
            num_decoder_layers=Config.num_decoder_layers,
            dim_feedforward=Config.dim_feedforward,
            dropout=Config.dropout,
            batch_first=True
        )

        self.fc1 = nn.Linear(Config.emb_dim, Config.tgt_vocab_size)

    def forward(self, src, tgt):
        # Add the global positional encoding to the token embeddings
        src_embedded = self.scr_embedding(src) + pe[:src.size(1)]
        tgt_embedded = self.tgt_embedding(tgt) + pe[:tgt.size(1)]

        tgt_mask = mask(tgt.size(1)).to(Config.device)

        output = self.Transformer(
            src_embedded, tgt_embedded,
            tgt_mask=tgt_mask
        )

        return self.fc1(output)

pe = positional_encoding().to(Config.device)
model = TransformerModel().to(Config.device)
optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)

def save_model_optimizer(epoch):
    print("save_model_optimizer aufgerufen")
    os.makedirs(Config.checkpoint_path, exist_ok=True)
    file_name = f"Checkpoint_num{Config.Neuste_Datei}__{Config.token_file}.pt"
    Checkpoint_path = os.path.join(Config.checkpoint_path, file_name)

    checkpoints = {
        "model_state_dict": model.state_dict(),
        "optimizer_stat_dict": optimizer.state_dict(),
        "epoch": epoch,
        "token_file": Config.token_file
    }

    torch.save(checkpoints, Checkpoint_path)
    print(f"Checkpoint unter {Checkpoint_path} gspeichert.")
    Config.Neuste_Datei += Config.Neuste_Datei

def load_model_optimizer():
    print("load_model_optimizer aufgerufen.")
    # Helper that extracts the checkpoint number from the file name
    def get_latest_Checkpoint(filename):
        match = re.search(r"Checkpoint_num(\d+)__(\d+).pt", filename)
        return int(match.group(1)) if match else 0  # if there is no match, epoch=0

    # List all checkpoint files
    checkpoint_files = [f for f in os.listdir(Config.checkpoint_path) if f.startswith("Checkpoint_epoch") and f.endswith(".pt")]

    if not checkpoint_files:
        print("Kein Checkpoint Datei gefunden.")
        return 0, 0  # start_epoch, current_file

    # Find the newest checkpoint file (highest number)
    latest_checkpoint = max(checkpoint_files, key=get_latest_Checkpoint)
    checkpoint_path = os.path.join(Config.checkpoint_path, latest_checkpoint)

    # Load the checkpoint
    print(checkpoint_path)
    checkpoint = torch.load(checkpoint_path, weights_only=True)

    # Restore the state
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_stat_dict"])
    epoch = checkpoint["epoch"]
    token_file = checkpoint["token_file"]

    print(f"Gewichte wurden von  {checkpoint_path} geladen. (Epoch: {epoch})")
    print(token_file)
    return token_file, epoch  # return the token file index and the epoch

class train_data_loader(IterableDataset):
    def __init__(self):
        print("train_data_loader aufgerufen")
        super().__init__()
        self.Aktuelle_Datei, eg = load_model_optimizer()
        self.files = sorted(
            [os.path.join(Config.data_dir, f)
             for f in os.listdir(Config.data_dir)
             if f.startswith("tokens_batch") and f.endswith(".pt")],
            key=natural_sort_key)  # use the custom natural sort order

        print(f"Daten Laden mit der {self.Aktuelle_Datei} Datei.")

    def __iter__(self):
        buffer = []
        print(self.Aktuelle_Datei)
        for self.Aktuelle_Datei in range(self.Aktuelle_Datei, len(self.files)):
            file_path = os.path.join(Config.data_dir, f"tokens_batch_{str(self.Aktuelle_Datei)}.pt")
            print(f"Der file path: {file_path}")

            tokens = torch.load(file_path, weights_only=True)
            tokens = tokens.cpu()

            # Slice the token stream into fixed-length sequences and yield them
            # from a small shuffle buffer
            for i in range(0, len(tokens) - Config.seq_lenght, Config.seq_lenght):
                buffer.append(tokens[i : i + Config.seq_lenght].clone().detach().long())
                if len(buffer) >= Config.buffer_size:
                    random.shuffle(buffer)
                    while buffer:
                        yield buffer.pop()

        random.shuffle(buffer)
        while buffer:
            yield buffer.pop()

dataset = train_data_loader()

def train(train_loader):
    print("train aufgerufen")
    with open("Loss.txt", "a") as Loss_file:
        token_file, start_epoch = load_model_optimizer()
        print(f"Traing startet in der {start_epoch + 1} mit der datei {token_file}")
        Config.token_file = token_file
        for epoch in range(start_epoch, start_epoch + Config.How_many_epochs):
            model.train()
            epoch_Loss = 0
            batch_id = 0
            every_User_Output = 10
            batch_without_User_Output = 0
            Losses = []
            batch_ids = []
            for batch in train_loader:
                # Prepare the data
                batch = batch.to(Config.device)
                src = batch[:, :-1]
                tgt = batch[:, 1:]

                # Compute the output
                optimizer.zero_grad()
                output = model(src, tgt)

                Loss = Config.criterion(output.reshape(-1, output.shape[-1]), tgt.reshape(-1))
                Loss.backward()

                optimizer.step()

                epoch_Loss += Loss.item()
                batch_id = batch_id + 1

                Losses.append(Loss.item())
                batch_ids.append(batch_id)

                if every_User_Output <= batch_without_User_Output:
                    Predicted_Token = output.argmax(dim=-1)
                    User_Output = Config.tokenizer.batch_decode(Predicted_Token, skip_speial_tokens=True)
                    print(f"Decoded Output {User_Output}")
                    batch_without_User_Output = 0
                    save_model_optimizer(epoch)

                print(f"In der {epoch + 1} Epoche beträgt das Loss im batch {batch_id}, {Loss.item()}.Letzter User Output vor {batch_without_User_Output} Batches.")

                Loss_file.write(f"In der {epoch + 1} Epoche beträgt das Loss im batch {batch_id}, {Loss.item()}\n")
                batch_without_User_Output += 1
                batch_id += 1

def generate_text(input_text, start_tocken=2, end_tocken=3):
    print("generate_text aufgerufen")
    eg, eg2 = load_model_optimizer()
    model.eval()

    src = Config.tokenizer.encode(input_text, return_tensors="pt").to(Config.device)
    tgt = torch.tensor([[start_tocken]], dtype=torch.long, device=Config.device)
    # Greedy decoding: repeatedly append the most likely next token
    for _ in range(1000):
        logits = model(src, tgt)
        next_token_logits = logits[:, -1, :]
        next_token = next_token_logits.argmax(dim=1, keepdim=True)

        tgt = torch.cat([tgt, next_token], dim=1)

        if next_token.item() == end_tocken:
            break
    generated_text = Config.tokenizer.decode(tgt[0].tolist())
    print(f"Generated text: {generated_text}")

def UI():
    print("LLM was möchtest du machen?")
    print("1)trainiren")
    print("2) text gereriren")
    wal = input()
    if int(wal) == 1:
        train_loader = DataLoader(dataset, batch_size=Config.batch_size, num_workers=Config.num_workers, pin_memory=Config.pin_memory)
        train(train_loader=train_loader)
    elif int(wal) == 2:
        generate_text("Ich bin ganz schön")
    else:
        print("Fehler in der wal")

UI()

Welcome to the PyTorch forums!

To avoid reinventing the wheel, I'd suggest you look at the Hugging Face libraries and tutorials. They are built on top of PyTorch but are more specific to LLMs. In fact, you might find it much less resource-intensive to take a pre-existing trained model and then finetune your own set of LoRA weights on your data. This has very low overhead and can train quickly and efficiently.

Here are their tutorials: Introduction - Hugging Face NLP Course

And here is a crash course for training LoRA weights: LoRA
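
For a rough idea of what that looks like in practice, here is a minimal sketch using the transformers, datasets and peft libraries; the corpus file name, output directory and hyperparameters below are placeholders, not recommendations:

from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments, DataCollatorForLanguageModeling)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Wrap the pretrained model so that only the small LoRA matrices are trained
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)

# Placeholder corpus: one plain-text file, tokenized to 128-token sequences
dataset = load_dataset("text", data_files={"train": "my_corpus.txt"})["train"]
dataset = dataset.map(lambda x: tokenizer(x["text"], truncation=True, max_length=128),
                      batched=True, remove_columns=["text"])

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="lora_out", per_device_train_batch_size=8,
                           num_train_epochs=1, learning_rate=2e-4),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

Since only the LoRA matrices are updated, this needs far less memory and compute than training a transformer from scratch.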

Lastly, in order to make your code more legible, please wrap it in 3 backticks on the front and back. The backtick is the character on the same key as ~.


Unfortunately, in the context of neural networks, running without errors does not mean that much. As long as the dimensions of the tensors work out, PyTorch won't throw an error. However, PyTorch cannot check whether the content of your tensors is correct.

Have you tried using a single encoder and decoder block, i.e., num_encoder_layers = 1 and num_decoder_layers = 1? This will significantly lower the total number of trainable parameters, and thus increase the chance of the loss going down after not that many iterations.
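
For example (just a sketch of the reduced setting, with the other values taken from your Config):

import torch.nn as nn

# Same architecture, but with a single encoder and a single decoder layer
transformer = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=1,   # instead of 6
    num_decoder_layers=1,   # instead of 6
    dim_feedforward=2048,
    dropout=0.1,
    batch_first=True,
)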

What "kind" of LLM are you trying to train, and what does your data look like? LLMs such as GPT, LLaMA, and most others are decoder-only architectures. However, you are using both an encoder and a decoder.
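
To make the distinction concrete, here is a rough, hypothetical sketch of a decoder-only language model in plain PyTorch; the class name and defaults are made up for illustration and only roughly mirror your hyperparameters:

import torch
import torch.nn as nn

# A single token stream goes through self-attention blocks with a causal mask;
# there is no separate encoder input.
class DecoderOnlyLM(nn.Module):
    def __init__(self, vocab_size, emb_dim=512, nhead=8, num_layers=6, seq_len=128):
        super().__init__()
        self.tok_embedding = nn.Embedding(vocab_size, emb_dim)
        self.pos_embedding = nn.Embedding(seq_len, emb_dim)
        block = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead,
                                           dim_feedforward=2048, dropout=0.1,
                                           batch_first=True)
        self.blocks = nn.TransformerEncoder(block, num_layers=num_layers)
        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, x):                       # x: (batch, seq_len) token ids
        positions = torch.arange(x.size(1), device=x.device)
        h = self.tok_embedding(x) + self.pos_embedding(positions)
        # Additive causal mask: -inf strictly above the diagonal
        causal_mask = torch.triu(
            torch.full((x.size(1), x.size(1)), float("-inf"), device=x.device),
            diagonal=1)
        h = self.blocks(h, mask=causal_mask)
        return self.lm_head(h)                  # (batch, seq_len, vocab_size)

With a model like this, training feeds batch[:, :-1] as the input and uses batch[:, 1:] as the targets for the cross-entropy loss, rather than passing a separate src and tgt.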