Initial D_KL loss is high and decreases very slowly

I'm trying to build a test bench based on the speculative decoding idea, where the draft model is an LLM fine-tuned to predict the target LLM's tokens for a given prompt. My goal is to make the draft model's probability distribution as close as possible to the target model's, which is why I'm using D_{KL} as the training loss.
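
To make the objective concrete, here is a toy sketch (dummy shapes and random logits, not the actual training code) of the per-token quantity I want to minimize, D_{KL}(P_target || Q_draft):

import torch
import torch.nn.functional as F

# Toy example: batch of 1, sequence of 4 tokens, vocabulary of 10 (dummy sizes).
target_logits = torch.randn(1, 4, 10)  # what the big (target) model would output
draft_logits = torch.randn(1, 4, 10)   # what the small (draft) model would output

# D_KL(P_target || Q_draft) per position: F.kl_div takes log-probabilities as input
# and probabilities as target, so "input" is the draft and "target" is the big model.
kl_per_token = F.kl_div(
    F.log_softmax(draft_logits, dim=-1),
    F.softmax(target_logits, dim=-1),
    reduction="none",
).sum(dim=-1)                          # sum over the vocab -> shape (1, 4), one value per token

print(kl_per_token.mean())             # one scalar summary of how far the draft is from the target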

Computing the large model's logits on every training step is quite expensive, so I precompute them and save them to files. When I started training (about 20,000 examples overall for the 125M model; maybe that's too little, but I don't think it's the main problem), the loss decreased very slowly from epoch to epoch and was still around 3-4 at the 10th epoch.

Maybe my loss calculation is wrong, maybe the preprocessing is wrong; I can't tell. As a sanity check, I tried training the model on a single batch to see whether the loss behaves reasonably, and it reached about 0.1 after roughly 200 iterations.
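
A minimal way to reproduce that single-batch check with Lightning (using the Lit module, draft_model and datamodule from the snippets below; this is not my exact script, just the idea):

# Single-batch sanity check: keep reusing the same batch via Lightning's overfit_batches flag.
sanity_trainer = L.Trainer(
    accelerator="gpu",
    devices=[1],               # same GPU index as CUDA_DEVICE ("cuda:1")
    max_epochs=200,            # ~200 passes over the single batch
    overfit_batches=1,         # train (and validate) on one batch only
    logger=False,
    enable_checkpointing=False,
)
sanity_trainer.fit(model=Lit(draft_model=draft_model, learning_rate=1e-6), datamodule=datamodule)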

Here are code snippets:

Some constants:

EPS = 1e-4
TARGET_MODEL="facebook/opt-13b"
DRAFT_MODEL="facebook/opt-1.3b"
SMALL_TARGET_MODEL="facebook/opt-1.3b"
SMALL_DRAFT_MODEL="facebook/opt-125m"
CUDA_DEVICE="cuda:1"

Datamodule:

from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import lightning as L
from torch.utils.data import DataLoader
from datasets import load_dataset
import os
import torch
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

from constants import CUDA_DEVICE

def collate_fn(batch, target_model=None, tokenizer=None):
    if target_model is not None:
        if tokenizer is None:
            raise Exception("You should provide tokenizer too!")
        input_ids_padded = pad_sequence(
            [item["input_ids"] for item in batch],
            batch_first=True,
            padding_value=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        )
        logits_padded = pad_sequence(
            [item["logits"].squeeze(0) for item in batch],
            batch_first=True,
            padding_value=0
        )
        return {
            "input_ids": input_ids_padded,
            "logits": logits_padded
        }
    return torch.utils.data.default_collate(batch)

class WikiTextV2Datamodule(L.LightningDataModule):
    def __init__(
        self,
        min_len: int,
        max_len: int,
        target_model=None,
        target_model_tokenizer=None,
        device=CUDA_DEVICE,
        num_workers: int = 0,
        batch_size: int = 16,
        check_cache=True,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.min_len = min_len 
        self.max_len = max_len 
        self.num_workers = num_workers
        self.target_model = target_model
        self.target_model_tokenizer = target_model_tokenizer
        self.device = device
        self.check_cache = check_cache
   
    def setup(self, stage) -> None:
        train_data = load_dataset("Salesforce/wikitext", "wikitext-2-v1", split="train")
        test_data = load_dataset("Salesforce/wikitext", "wikitext-2-v1", split="test")
        
        # If model is not None, batch has structure {"text": [N strings]} where N is batchsize    
        self.train_dataset = self.filter_dataset(train_data, self.min_len, self.max_len)
        self.val_dataset = self.filter_dataset(test_data, self.min_len, self.max_len)
        
        if self.target_model is not None:
            if self.target_model_tokenizer is None:
                raise Exception("You should provide target model tokenizer for fine-tuning too!")
            self.prepare_dataset_for_draft_model_finetuning()

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=lambda batch: collate_fn(batch, self.target_model, self.target_model_tokenizer)
        )

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=lambda batch: collate_fn(batch, self.target_model, self.target_model_tokenizer)
        )
    
    def prepare_dataset_for_draft_model_finetuning(self):
        cache_dir = "data"
        os.makedirs(cache_dir, exist_ok=True)
        train_cache_path = os.path.join(cache_dir, "train.pt")
        test_cache_path = os.path.join(cache_dir, "test.pt")
        
        if self.check_cache and os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
            print("Found cached data. Loading...")
            self.train_dataset = torch.load(train_cache_path)
            self.val_dataset = torch.load(test_cache_path)
            print(f"Loaded preprocessed data from cache")
        else:
            def process_dataset(dataset, desc):
                processed_data = []
                for item in tqdm(dataset, desc=desc):
                    input_text = item["text"]
                    
                    tokenized_input = self.target_model_tokenizer(input_text, return_tensors="pt").to(self.device)
                    input_ids = tokenized_input.input_ids  # shape (1, seq_len)
                    with torch.no_grad():
                        logits = self.target_model(input_ids).logits  # shape (1, seq_len, vocab_size)
                    processed_data.append({
                        "input_ids": input_ids[0].cpu(),
                        "logits": logits.cpu()
                    })
                
                return processed_data
            
            self.train_dataset = process_dataset(self.train_dataset, "Processing train dataset")
            self.val_dataset = process_dataset(self.val_dataset, "Processing test dataset")
            
            # Save to cache
            torch.save(self.train_dataset, train_cache_path)
            torch.save(self.val_dataset, test_cache_path)
            print(f"Processed datasets and saved to cache")

    @staticmethod
    def filter_dataset(dataset, min_len: int, max_len: int):
        # Keep only rows whose whitespace-split length is within [min_len, max_len].
        return dataset.filter(lambda row: min_len <= len(row["text"].split()) <= max_len)

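Just to make the data layout concrete, this is what I expect collate_fn above to produce for two variable-length items (toy tensors and a tiny vocabulary, purely for illustration):

import types
import torch

# Two fake cached items: sequence lengths 3 and 5, vocabulary of 11 (dummy values).
dummy_batch = [
    {"input_ids": torch.arange(3), "logits": torch.randn(1, 3, 11)},
    {"input_ids": torch.arange(5), "logits": torch.randn(1, 5, 11)},
]
fake_tokenizer = types.SimpleNamespace(pad_token_id=None, eos_token_id=2)

out = collate_fn(dummy_batch, target_model=object(), tokenizer=fake_tokenizer)
print(out["input_ids"].shape)  # torch.Size([2, 5]) -- shorter row right-padded with eos_token_id
print(out["logits"].shape)     # torch.Size([2, 5, 11]) -- padded positions are all-zero logits
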
Fine-tuning attempt:

import os

import torch
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.optimization import Adafactor, AdafactorSchedule

from constants import CUDA_DEVICE, DRAFT_MODEL, EPS, SMALL_DRAFT_MODEL, SMALL_TARGET_MODEL, TARGET_MODEL
from datamodule import WikiTextV2Datamodule


class Lit(L.LightningModule):
    def __init__(
        self, draft_model, learning_rate=1e-6, weight_decay=0.01,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['draft_model'])
        self.draft_model = draft_model
        self.draft_model.train()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
    
    def forward(self, *args, **kwargs):
        return self.draft_model(*args, **kwargs)
    
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]    # (batch, seq_len), padded
        target_logits = batch["logits"]   # (batch, seq_len, vocab_size), cached target-model logits

        draft_logits = self.draft_model(input_ids).logits
        log_draft_probs = F.log_softmax(draft_logits, dim=-1)
        target_probs = F.softmax(target_logits, dim=-1)

        # KL(target || draft): kl_div expects log-probabilities as input and probabilities as target
        loss = F.kl_div(log_draft_probs, target_probs, reduction='batchmean')
        self.log("train_loss", loss, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        target_logits = batch["logits"]
    
        draft_logits = self.draft_model(input_ids).logits
        log_draft_probs = F.log_softmax(draft_logits, dim=-1)
        target_probs = F.softmax(target_logits, dim=-1)
        
        loss = F.kl_div(log_draft_probs, target_probs, reduction='batchmean')
        self.log("val_loss", loss, prog_bar=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            [p for p in self.draft_model.parameters() if p.requires_grad],
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=10,
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler
        }

def create_peft_config(model):
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    return get_peft_model(prepare_model_for_kbit_training(model), peft_config)



if __name__ == "__main__":
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.set_float32_matmul_precision('medium')

    target_model = AutoModelForCausalLM.from_pretrained(SMALL_TARGET_MODEL).to(CUDA_DEVICE)
    target_model_tokenizer = AutoTokenizer.from_pretrained(SMALL_TARGET_MODEL)
    draft_model = create_peft_config(AutoModelForCausalLM.from_pretrained(SMALL_DRAFT_MODEL).to(CUDA_DEVICE))
    #target_model = AutoModelForCausalLM.from_pretrained(TARGET_MODEL).to(CUDA_DEVICE)
    #target_model_tokenizer = AutoTokenizer.from_pretrained(TARGET_MODEL)
    #draft_model = create_peft_config(AutoModelForCausalLM.from_pretrained(DRAFT_MODEL).to(CUDA_DEVICE))
    
    callbacks = [
        ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, save_last=False),
    ]
    
    # TODO: parameters for testing; don't forget to change them for real training
    datamodule = WikiTextV2Datamodule(
        min_len=5,  
        max_len=20,
        target_model=target_model,
        target_model_tokenizer=target_model_tokenizer,
        device=CUDA_DEVICE,
        batch_size=8, 
        check_cache=True,
        num_workers=25
    )

    trainer = L.Trainer(
        accelerator="gpu", max_epochs=10, 
        limit_train_batches=None,
        logger=False,
        devices=[int(CUDA_DEVICE.split(":")[-1])],
        callbacks=callbacks,
    )

    fine_tuned_model = Lit(draft_model=draft_model, learning_rate=1e-6)
    trainer.fit(model=fine_tuned_model, datamodule=datamodule)
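
In case it matters, this is how the trainable-parameter share of the LoRA-wrapped draft model can be inspected (illustration only, using peft's print_trainable_parameters; not part of the script above):

# Illustration only: check how much of the LoRA-wrapped draft model is actually trainable.
peft_draft = create_peft_config(AutoModelForCausalLM.from_pretrained(SMALL_DRAFT_MODEL))
peft_draft.print_trainable_parameters()
# Prints something like: trainable params: ... || all params: ... || trainable%: ...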

Thank you in advance.