Is this a correct implementation of perplexity?

Greetings everyone,

I am comparing the perplexity of language models built with KenLM and with a GRU on the same dataset, i.e. the same vocabulary. At first, the perplexity I computed from the GRU model did not seem right. I went through several iterations before I think I found the correct formula to compute it from the loss. What caused the confusion is that I wanted to calculate the per-token perplexity, to be consistent with how KenLM reports perplexity (the source of this can be found here).
I just want to make sure my implementation is correct and as intended.
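
Just to state explicitly what I am after: by per-token perplexity I mean the exponential of the average negative log-likelihood per token over the whole test set. A tiny sketch with made-up numbers (token_nlls is a hypothetical tensor of per-token cross-entropy losses, not something from my code):

import torch

# token_nlls: hypothetical 1-D tensor holding the cross-entropy (negative
# log-likelihood) of every scored token in the test set
token_nlls = torch.tensor([2.1, 3.4, 1.7, 2.9])

# per-token perplexity = exp(mean negative log-likelihood per token)
perplexity = token_nlls.mean().exp()
print(perplexity)  # exp(2.525) ~= 12.49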

I wrote my model using PyTorch Lightning. The model code:

class LitNeurlaLanguageModel(LightningModule):
    def __init__(
        self,
        vocab_size,
        unk_token_id = 0,
        pad_token_id = 1,
        num_layers = NUM_LAYERS,
        hidden_size = HIDDEN_SIZE,
        dropout_prob = DROPOUT_PROB,
        learning_rate = LEARNING_RATE,
        embedding_size = EMBEDDING_SIZE,
    ):

        super().__init__()
        self.save_hyperparameters()

        self.pad_token_id = pad_token_id
        self.unk_token_id = unk_token_id
        
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.learning_rate = learning_rate
        self.embedding_size = embedding_size

        self.embedding_layer = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embedding_size,
        )
        self.gru_layer = nn.GRU(
            input_size=self.embedding_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.first_dense_layer = nn.Linear(
            in_features=self.hidden_size,
            out_features=self.hidden_size,
        )
        self.dropout_layer = nn.Dropout(p=self.dropout_prob)
        self.relu = nn.ReLU()
        self.second_dense_layer = nn.Linear(
            in_features=self.hidden_size,
            out_features=self.vocab_size,
        )

    def forward(self, x, hiddens=None):
        # in case hiddens are passed, useful in inference
        outputs = self.embedding_layer(x)
        if hiddens is None:
            outputs, hiddens = self.gru_layer(outputs)
        else:
            outputs, hiddens = self.gru_layer(outputs, hiddens)
        outputs = self.first_dense_layer(outputs)
        outputs = self.dropout_layer(outputs)
        outputs = self.relu(outputs)
        outputs = self.second_dense_layer(outputs)
        return outputs, hiddens


    def _get_loss(self, batch, ignore_oovs=False, loss_reduction='mean'):
        inputs, labels = batch
        # batch here, is as follows:
        # inputs: [[id1,id2, ...,id(k-1)],...,[id1,id2, ...,id(k-1)]], k is the sequence length
        # labels: [[id2,id3, ...,id(k)],...,[id2,id3, ...,id(k)]]
        outputs, hiddens = self(inputs)
        # https://discuss.pytorch.org/t/cross-entropy-loss-for-a-sequence-time-series-of-output/4309
        # reshaping outputs to size [-1,vocab_size]
        outputs = outputs.view(-1, self.vocab_size)
        # reshaping labels to a flat vector of length batch_size * sequence_length
        labels = labels.view(-1)
        if ignore_oovs:
            # it might be good to report the results without oovs sometimes? 
            # https://discuss.pytorch.org/t/when-to-use-ignore-index/5935/11
            labels[labels == self.unk_token_id] = self.pad_token_id
        # when calculating perplexity, make sure to pass 'none' as loss_reduction
        # so that the loss is averaged over tokens from all batches in the test set
        # ----
        # this loss calculation ignores padding to improve model performance
        loss = F.cross_entropy(outputs, labels, ignore_index=self.pad_token_id, reduction=loss_reduction)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._get_loss(batch)
        self.log(
            "loss",
            loss,
            on_step=True,
            on_epoch=False,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return {"val_loss": self._get_loss(batch)}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
        )
        return optimizer

The method that calculates perplexity:

def calculate_perpelixity(
    lm_model,
    dataset,
    tokenizer,
    use_tqdm=True,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    ignore_oovs = False,
    sequence_length = SEQUENCE_LENGTH,
):
    # https://towardsdatascience.com/the-relationship-between-perplexity-and-entropy-in-nlp-f81888775ccc
    # https://stackoverflow.com/a/59219379/4412324
    lm_dataset = LanguageModelDataset(
        dataset=dataset,
        tokenizer=tokenizer,
    )
    dataloader = DataLoader(
        shuffle=False,
        dataset=lm_dataset,
        batch_size=batch_size,
        num_workers=CPU_COUNT,
        collate_fn=dataset_collate_fn,
        drop_last=len(lm_dataset) > batch_size,
    )
    lm_model.to(device)
    lm_model.eval()
    with torch.no_grad():
        loader = tqdm(dataloader) if use_tqdm else dataloader
        losses = []
        for batch in loader:
            inputs, outputs = batch
            inputs = inputs.to(device)
            outputs = outputs.to(device)
            batch_losses = lm_model._get_loss(
                (inputs, outputs),
                ignore_oovs=ignore_oovs,
                loss_reduction='none'
            )
            losses.append(batch_losses)
    # concatenate the per-token losses from all batches
    losses = torch.cat(losses)
    perplexity = losses.mean().exp()
    return perplexity
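
One thing I am still double-checking (my own reading of F.cross_entropy, not something from KenLM): with reduction='none', positions whose label equals ignore_index still appear in the output, just with a loss of 0, so averaging the flattened losses counts padding (and, with ignore_oovs=True, the remapped UNK tokens) in the denominator and pulls the perplexity down. A minimal sketch of how one could mask those positions before averaging; token_losses and flat_labels are hypothetical names for the tensors produced above:

import torch

def masked_perplexity(token_losses, flat_labels, pad_token_id):
    # token_losses: 1-D tensor from F.cross_entropy(..., reduction='none'),
    #               where ignored (padding) positions come back as 0
    # flat_labels:  1-D tensor of the corresponding label ids
    keep = flat_labels != pad_token_id
    # average only over real tokens, then exponentiate
    return token_losses[keep].mean().exp()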

When training this model on a small dataset of 77k tokens overall and a vocabulary of about 14.4k, the best perplexity I am getting from KenLM is around 265. However, the perplexity I get from this GRU model is 33, a significant but expected(?) improvement.

Is this a correct implementation of perplexity given this setup?

I also noticed something: when I use reduction='mean' while calculating perplexity, the result explodes! Is there any explanation for that? Theoretically speaking, it should be the same, since it would just amount to a two-step mean reduction?

Help here is appreciated

Thanks much.

Perplexity is supposed to be P = exp(L), where L is the loss (typically a cross-entropy, i.e. -∑ y log(a)). It seems like you are getting the loss per batch (which is an average over all items in a batch) and then averaging over these losses. Assuming all your batches are the same size, this is the same as taking the loss over the entire dataset in the data loader, so this seems right. I have not dug around to see what KenLM is, but is it this: https://kheafield.com/papers/avenue/kenlm.pdf ? That model seems quite old now (2011?), so a GRU (2014+) would be expected to do better.
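
A quick numeric sketch of that point, with made-up loss values and two equally sized batches:

import torch

# hypothetical per-token losses split into two equally sized batches
batch_a = torch.tensor([2.0, 3.0, 4.0])
batch_b = torch.tensor([1.0, 2.0, 3.0])

# mean of the per-batch means: (3.0 + 2.0) / 2 = 2.5
mean_of_means = torch.stack([batch_a.mean(), batch_b.mean()]).mean()

# mean over all tokens at once: 15.0 / 6 = 2.5
global_mean = torch.cat([batch_a, batch_b]).mean()

# identical for equal batch sizes, so exp() of either gives the same perplexity
print(mean_of_means, global_mean)  # tensor(2.5000) tensor(2.5000)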

Where are you referring to “reduction=‘mean’” in your question? It seems you are using ‘mean’ (the default in _get_loss), so I’m not sure what you are referring to.

I think I was confused by this implementation. However, what you mentioned, @dreidizzle, is right. Thanks for your comment. I really think it is much more concise to use the implementation from torchmetrics.
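
For anyone who finds this later, here is roughly what I mean: a minimal sketch assuming the torchmetrics.text.Perplexity metric and its ignore_index argument (check the docs for your torchmetrics version, in particular whether preds should be raw logits or probabilities):

import torch
from torchmetrics.text import Perplexity

# hypothetical shapes: 2 sequences, 5 tokens each, vocabulary of 10
logits = torch.randn(2, 5, 10)
targets = torch.randint(0, 10, (2, 5))

# positions whose target equals ignore_index (the padding id here) are
# excluded from the average automatically
metric = Perplexity(ignore_index=1)
print(metric(logits, targets))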