How to load a large text file into a Dataset for LLM pretraining

I have enough RAM to load the large text file itself, but during the processing inside the Dataset logic my kernel dies.

I tried loading the text file in chunks, but the problem is that I want long sequences of 256 tokens, and in some cases the train loader returns nothing because a chunk tokenizes to fewer than 256 tokens.

This is my current code:

with open('data/shards/ta_dedup.txt_shard_1.txt', 'r', encoding='utf-8') as f:
    rawtext = f.read()

import multiprocessing
from typing import List, Tuple
import sentencepiece as spm
import torch
from torch.utils.data import DataLoader, Dataset, random_split

class TamilDataset(Dataset):
    def __init__(self, text: str, tokenizer: spm.SentencePieceProcessor, max_length: int, stride: int, debug: bool = False):
        """
        PyTorch Dataset for tokenized Tamil text.
         
        Args:
            path (str): Path to the text file.
            offset_dict (List[int]): List of line offsets.
            tokenizer (spm.SentencePieceProcessor): SentencePiece tokenizer.
            max_length (int): Maximum sequence length.
            stride (int): Stride for overlapping chunks.
        """
        self.rawtext = text
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride
        self.debug = debug

        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = self.tokenizer.encode(self.rawtext)

        # Use a sliding window to chunk the text into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))
    
    def __len__(self) -> int:
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader(
    rawtext: str, 
    batch_size: int, 
    max_length: int, 
    stride: int, 
    shuffle: bool, 
    drop_last: bool, 
    num_workers: int,
    train_split: float = 0.75
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation dataloaders with a specified split.
    
    Args:
        rawtext (str): Raw text to tokenize.
        batch_size (int): Batch size for dataloaders.
        max_length (int): Maximum sequence length.
        stride (int): Stride for overlapping chunks.
        shuffle (bool): Whether to shuffle the data.
        drop_last (bool): Whether to drop the last incomplete batch.
        num_workers (int): Number of worker processes for data loading.
        train_split (float, optional): Proportion of data to use for training. Defaults to 0.75.
    
    Returns:
        Tuple of train and validation DataLoaders.
    """

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load('models/tok32000.model')

    
    dataset = TamilDataset(rawtext, tokenizer, max_length, stride, debug=True)
    
    train_size = int(len(dataset) * train_split)
    val_size = len(dataset) - train_size
    
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=shuffle, 
        drop_last=drop_last, 
        num_workers=num_workers
    )
    
    val_dataloader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        drop_last=drop_last, 
        num_workers=num_workers
    )
    
    return train_dataloader, val_dataloader

train_dataloader, val_dataloader = create_dataloader(
    rawtext, 
    batch_size=1, 
    max_length=256, 
    stride=1, 
    shuffle=True, 
    drop_last=False, 
    num_workers=0
)
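
To make the failure concrete: when a shard tokenizes to fewer than max_length tokens, the sliding-window range in __init__ is empty, so the dataset has zero samples and the DataLoader yields nothing. A toy illustration (the numbers are made up):

token_ids = list(range(100))   # pretend the shard tokenized to only 100 ids
max_length, stride = 256, 1
windows = range(0, len(token_ids) - max_length, stride)
print(len(windows))            # 0 -> no (input, target) pairs, empty loader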

Would it work to define max_length as min(len(token_ids), max_length)?
This would make sure the created overlapping windows are not empty. (You might also need to subtract the stride.)
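
For example, something along these lines inside TamilDataset.__init__ (just a sketch; effective_length is a placeholder name, and the -1 keeps the shifted target the same length as the input):

        # Clamp the window size so even a short shard yields at least one window;
        # the -1 leaves room for the target, which is shifted by one token.
        effective_length = min(len(token_ids) - 1, max_length)

        for i in range(0, len(token_ids) - effective_length, stride):
            input_chunk = token_ids[i:i + effective_length]
            target_chunk = token_ids[i + 1:i + effective_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))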


Hi @ptrblck, thanks for your answer. I was able to work around it by writing a lazy loader, but the caveat I'm facing is that, since the dataset is now iterable, I can't implement the total-length (__len__) method (I've put the rough idea I'm considering after the code below). Would you have any suggestions for this?

Thanks so much.

import os
import random
from typing import Iterator, Optional, Tuple

import sentencepiece as spm
import torch
from torch.utils.data import DataLoader, IterableDataset

class LazyTamilDataset(IterableDataset):
    def __init__(self,
                 file_path: str,
                 tokenizer: spm.SentencePieceProcessor,
                 max_length: int,
                 stride: int,
                 split: Optional[str] = None,
                 train_ratio: float = 0.8,
                 seed: int = 42,
                 debug: bool = False):
        """
        Lazy loading dataset for large text files with memory-efficient tokenization and train/val split.
        
        Args:
            file_path (str): Path to the text file.
            tokenizer (spm.SentencePieceProcessor): SentencePiece tokenizer.
            max_length (int): Maximum sequence length.
            stride (int): Stride for overlapping chunks.
            split (str, optional): 'train' or 'val'. None uses the full dataset.
            train_ratio (float): Proportion of data to use for training.
            seed (int): Random seed for reproducibility.
            debug (bool): Enable debug mode.
        """
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride
        self.split = split
        self.train_ratio = train_ratio
        self.seed = seed
        self.debug = debug

        random.seed(seed)
        torch.manual_seed(seed)

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

    def _generate_chunks(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Generator method to yield tokenized chunks lazily with train/val splitting.
        """
        with open(self.file_path, 'r', encoding='utf-8') as f:
            chunk_counter = 0
            while True:
                chunk = f.read(1024 * 1024)  # Read ~1M characters at a time (text mode counts characters, not bytes)
                if not chunk:
                    break

                token_ids = self.tokenizer.encode(chunk)

                for i in range(0, len(token_ids) - self.max_length, self.stride):
                    if self.split is not None:
                        # Deterministic split keyed on the window counter, so the
                        # 'train' and 'val' iterators see disjoint, reproducible subsets
                        is_train = random.Random(self.seed + chunk_counter).random() < self.train_ratio

                        if (self.split == 'train' and not is_train) or \
                           (self.split == 'val' and is_train):
                            chunk_counter += 1
                            continue

                    input_chunk = token_ids[i:i + self.max_length]
                    target_chunk = token_ids[i + 1: i + self.max_length + 1]

                    yield (
                        torch.tensor(input_chunk, dtype=torch.long),
                        torch.tensor(target_chunk, dtype=torch.long)
                    )

                    chunk_counter += 1

    def __iter__(self):
        """
        Make the dataset iterable for lazy loading.
        """
        return self._generate_chunks()

def create_lazy_split_dataloader(
    file_path: str,
    batch_size: int,
    max_length: int,
    stride: int,
    train_ratio: float = 0.8,
    num_workers: int = 0,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader]:
    """
    Create lazy loading train and validation DataLoaders with a split.
    
    Args:
        file_path (str): Path to the text file.
        batch_size (int): Batch size for dataloaders.
        max_length (int): Maximum sequence length.
        stride (int): Stride for overlapping chunks.
        train_ratio (float): Proportion of data to use for training.
        num_workers (int): Number of worker processes for data loading.
        seed (int): Random seed for reproducibility.
    
    Returns:
        Tuple of train and validation DataLoaders.
    """

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load('models/tok32000.model')


    train_dataset = LazyTamilDataset(
        file_path = file_path,
        tokenizer = tokenizer,
        max_length = max_length,
        stride = stride,
        split = 'train',
        train_ratio = train_ratio,
        seed = seed
    )

    val_dataset = LazyTamilDataset(
        file_path = file_path,
        tokenizer = tokenizer,
        max_length = max_length,
        stride = stride,
        split = 'val',
        train_ratio = train_ratio,
        seed = seed
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size = batch_size,
        num_workers = num_workers
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size = batch_size,
        num_workers = num_workers
    )

    return train_dataloader, val_dataloader
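
The rough __len__ idea mentioned above: a one-off counting pass over the file that replays the same windowing as _generate_chunks without keeping anything in memory (the _length cache attribute is just a name I made up, and it counts windows before the train/val filtering, so for a split dataset it would only be an upper bound):

    def __len__(self) -> int:
        # One-off pass that counts how many windows _generate_chunks would yield
        # (before train/val filtering), then caches the result.
        if not hasattr(self, '_length'):
            count = 0
            with open(self.file_path, 'r', encoding='utf-8') as f:
                while True:
                    chunk = f.read(1024 * 1024)
                    if not chunk:
                        break
                    token_ids = self.tokenizer.encode(chunk)
                    count += len(range(0, len(token_ids) - self.max_length, self.stride))
            self._length = count
        return self._length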