Hello everyone,
In the past few weeks I have programmed my own LLM in PyTorch, and after a lot of debugging the code now runs without errors. However, after about 500 batches the loss drops to around 0.1, yet the actual text generation still produces nothing meaningful. My data is stored in multiple .pt files (tokens_batch_0.pt, tokens_batch_1.pt, …), already tokenized and batched.
Sorry for the long text and the large amount of code, but I really don’t know what to do anymore.
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
from transformers import GPT2TokenizerFast
from tqdm import tqdm
import os
import time
import math
import random
import re
import matplotlib.pyplot as plt
import warnings
from dataclasses import dataclass
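# Central configuration: hyperparameters, tokenizer, transformer sizes, paths, and device.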
@dataclass
class Config:
#Hyperparameters
learning_rate = 1e-4
criterion = nn.CrossEntropyLoss()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
#Transformer parameters
src_vocab_size = len(tokenizer)
tgt_vocab_size = len(tokenizer)
seq_lenght = 128
emb_dim = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
dropout = 0.1
#Model Parameters
epochs = 5
batch_size = 32
data_dir = "C:\\Users\\Bennet\\Desktop\\Projekte\\token_batches"
checkpoint_path = "C:\\Users\\Bennet\\Desktop\\Projekte\\Encoder_Model1\\Checkpoints"
buffer_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#__name__ == "__main__" loop
num_workers = 0 #os.cpu_count()
pin_memory = True
How_many_epochs = 5
#checkpoint variables
Neuste_Datei = 0
token_file = 0
Config = Config()
Config.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
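# Builds the additive attention mask that is passed as tgt_mask to the decoder.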
def mask(size):
print("mask called")
mask = torch.triu(torch.ones(size, size))
mask = mask.masked_fill(mask == 1, float("-inf"))
return mask
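# Natural sort key so that e.g. tokens_batch_2.pt sorts before tokens_batch_10.pt.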
def natural_sort_key(s):
print("natural_sort_key called")
return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]
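# Precomputes a sinusoidal positional encoding table of shape (seq_lenght, emb_dim).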
def positional_encoding():
print("positional_encoding called")
pe = torch.zeros(Config.seq_lenght, Config.emb_dim)
for pos in range(Config.seq_lenght):
for i in range(0, Config.emb_dim, 2):
pe[pos, i] = math.sin(pos / (10000 ** (2*i / Config.emb_dim)))
pe[pos, i + 1] = math.cos(pos / (10000 ** (2 * i / Config.emb_dim)))
return pe
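# Encoder-decoder Transformer: token embeddings + positional encoding -> nn.Transformer -> linear projection to the vocabulary.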
class TransformerModel(nn.Module):
def __init__(self):
print("TransformerModel class called")
super(TransformerModel, self).__init__()
self.scr_embedding = nn.Embedding(Config.src_vocab_size, Config.emb_dim)
self.tgt_embedding = nn.Embedding(Config.tgt_vocab_size, Config.emb_dim)
self.positional_Encoding = positional_encoding
self.Transformer = nn.Transformer(
d_model=Config.emb_dim,
nhead=Config.nhead,
num_encoder_layers=Config.num_encoder_layers,
num_decoder_layers=Config.num_decoder_layers,
dim_feedforward=Config.dim_feedforward,
dropout=Config.dropout,
batch_first=True
)
self.fc1 = nn.Linear(Config.emb_dim, Config.tgt_vocab_size)
def forward(self, src, tgt):
src_embedded = self.scr_embedding(src) + pe[:src.size(1)]
tgt_embedded = self.tgt_embedding(tgt) + pe[:tgt.size(1)]
tgt_mask = mask(tgt.size(1)).to(Config.device)
output = self.Transformer(
src_embedded, tgt_embedded,
tgt_mask=tgt_mask
)
return self.fc1(output)
pe = positional_encoding().to(Config.device)
model = TransformerModel().to(Config.device)
optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)
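# Saves model and optimizer state together with the current epoch and token file index.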
def save_model_optimizer(epoch):
print("save_model_optimizer called")
os.makedirs(Config.checkpoint_path, exist_ok=True)
file_name = f"Checkpoint_num{Config.Neuste_Datei}__{Config.token_file}.pt"
Checkpoint_path = os.path.join(Config.checkpoint_path, file_name)
checkpoints = {
"model_state_dict": model.state_dict(),
"optimizer_stat_dict": optimizer.state_dict(),
"epoch": epoch,
"token_file": Config.token_file
}
torch.save(checkpoints, Checkpoint_path)
print(f"Checkpoint unter {Checkpoint_path} gspeichert.")
Config.Neuste_Datei += Config.Neuste_Datei
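# Finds the newest checkpoint by the number in its file name, restores model/optimizer state and returns (token_file, epoch).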
def load_model_optimizer():
print("load_model_optimizer called.")
# Helper to extract the epoch number from the file name
def get_latest_Checkpoint(filename):
match = re.search(r"Checkpoint_num(\d+)__(\d+).pt", filename)
return int(match.group(1)) if match else 0 # If there is no match, epoch=0
# List all checkpoint files
checkpoint_files = [f for f in os.listdir(Config.checkpoint_path) if f.startswith("Checkpoint_epoch") and f.endswith(".pt")]
if not checkpoint_files:
print("Kein Checkpoint Datei gefunden.")
return 0, 0 #Start_epoche, Current_file
# Neueste Checkpoint-Datei finden (höchste Epochenzahl)
latest_checkpoint = max(checkpoint_files, key=get_latest_Checkpoint)
checkpoint_path = os.path.join(Config.checkpoint_path, latest_checkpoint)
# Load the checkpoint
print(checkpoint_path)
checkpoint = torch.load(checkpoint_path, weights_only=True)
# Restore the state
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_stat_dict"])
epoch = checkpoint["epoch"]
token_file = checkpoint["token_file"]
print(f"Gewichte wurden von {checkpoint_path} geladen. (Epoch: {epoch})")
print(token_file)
return token_file, epoch # Return the token file index and the epoch
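# IterableDataset that streams fixed-length token sequences from the tokens_batch_*.pt files, shuffling inside a small buffer.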
class train_data_loader(IterableDataset):
def __init__(self):
print("train_data_loader called")
super().__init__()
self.Aktuelle_Datei, eg = load_model_optimizer()
self.files = sorted(
[os.path.join(Config.data_dir, f)
for f in os.listdir(Config.data_dir)
if f.startswith("tokens_batch") and f.endswith(".pt")],
key=natural_sort_key) # Use the custom sort function
print(f"Loading data starting with file {self.Aktuelle_Datei}.")
def __iter__(self):
buffer = []
print(self.Aktuelle_Datei)
for self.Aktuelle_Datei in range(self.Aktuelle_Datei, len(self.files)):
file_path = os.path.join(Config.data_dir, f"tokens_batch_{str(self.Aktuelle_Datei)}.pt")
print(f"Der file path: {file_path}")
tokens = torch.load(file_path, weights_only=True)
tokens = tokens.cpu()
for i in range(0, len(tokens) - Config.seq_lenght, Config.seq_lenght):
buffer.append(tokens[i : i + Config.seq_lenght].clone().detach().long())
if len(buffer) >= Config.buffer_size:
random.shuffle(buffer)
while buffer:
yield buffer.pop()
random.shuffle(buffer)
while buffer:
yield buffer.pop()
dataset = train_data_loader()
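# Training loop: next-token cross-entropy; the argmax output is periodically decoded, checkpoints are saved, and the loss is logged to Loss.txt.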
def train(train_loader):
print("train called")
with open("Loss.txt", "a") as Loss_file:
token_file, start_epoch = load_model_optimizer()
print(f"Traing startet in der {start_epoch + 1} mit der datei {token_file}")
Config.token_file = token_file
for epoch in range(start_epoch, start_epoch + Config.How_many_epochs):
model.train()
epoch_Loss = 0
batch_id = 0
every_User_Output = 10
batch_without_User_Output = 0
Losses = []
batch_ids = []
for batch in train_loader:
#Prepare the data
batch = batch.to(Config.device)
src = batch[:, :-1]
tgt = batch[:, 1:]
#Compute the output
optimizer.zero_grad()
output = model(src, tgt)
Loss = Config.criterion(output.reshape(-1, output.shape[-1]), tgt.reshape(-1))
Loss.backward()
optimizer.step()
epoch_Loss += Loss.item()
batch_id = batch_id + 1
Losses.append(Loss.item())
batch_ids.append(batch_id)
if every_User_Output <= batch_without_User_Output:
Predicted_Token = output.argmax(dim=-1)
User_Output = Config.tokenizer.batch_decode(Predicted_Token, skip_special_tokens=True)
print(f"Decoded Output {User_Output}")
batch_without_User_Output = 0
save_model_optimizer(epoch)
print(f"In der {epoch + 1} Epoche beträgt das Loss im batch {batch_id}, {Loss.item()}.Letzter User Output vor {batch_without_User_Output} Batches.")
Loss_file.write(f"In der {epoch + 1} Epoche beträgt das Loss im batch {batch_id}, {Loss.item()}\n")
batch_without_User_Output += 1
batch_id += 1
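# Greedy autoregressive generation: the prompt is the encoder input, tgt is extended by one argmax token per step.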
def generate_text(input_text, start_tocken=2, end_tocken=3):
print("generate_text called")
eg, eg2 = load_model_optimizer()
model.eval()
src = Config.tokenizer.encode(input_text, return_tensors = "pt").to(Config.device)
tgt = torch.tensor([[start_tocken]], dtype=torch.long, device=Config.device)
for _ in range(1000):
logits = model(src, tgt)
next_token_logits = logits[:, -1, :]
next_token = next_token_logits.argmax(dim=1, keepdim=True)
tgt = torch.cat([tgt,next_token],dim=1)
if next_token.item() == end_tocken:
break
generated_text = Config.tokenizer.decode(tgt[0].tolist())
print(f"Generated text: {generated_text}")
def UI():
print("LLM, what would you like to do?")
print("1) train")
print("2) generate text")
wal = input()
if int(wal) == 1:
train_loader = DataLoader(dataset, batch_size=Config.batch_size, num_workers=Config.num_workers, pin_memory=Config.pin_memory)
train(train_loader=train_loader)
elif int(wal) == 2:
generate_text("Ich bin ganz schön")
else:
print("Error in the selection")
UI()