Using cross_entropy's ignore_index instead of pack_padded_sequence

Suppose I’m using cross_entropy loss to do language modelling (to predict the next element in a sequence).

I have sequences with different lengths that I want to batch together, and the usual solution is to sort them by length, pad them with a special symbol (say 0), call pack_padded_sequence(), feed the result to an RNN and finally call pad_packed_sequence().
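A minimal sketch of the usual pipeline described above, assuming a hypothetical toy batch of token ids and arbitrary embedding/hidden sizes. Instead of sorting by hand, it passes enforce_sorted=False to pack_padded_sequence, which sorts internally and restores the original batch order when unpacking.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Hypothetical toy batch of token-id sequences (id 0 is reserved for padding).
seqs = [torch.tensor([3, 1, 4, 1, 5]), torch.tensor([2, 7]), torch.tensor([6, 2, 8])]
lengths = torch.tensor([len(s) for s in seqs])

# 1. Pad to a (batch, max_len) tensor using the padding symbol 0.
padded = pad_sequence(seqs, batch_first=True, padding_value=0)

embed = nn.Embedding(10, 4)
lstm = nn.LSTM(input_size=4, hidden_size=5, batch_first=True)

# 2. Pack; enforce_sorted=False lets PyTorch do the length sorting internally.
packed = pack_padded_sequence(embed(padded), lengths, batch_first=True, enforce_sorted=False)

# 3. Run the RNN on the packed batch, then unpack back to a padded tensor.
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(padded.shape, out.shape)  # torch.Size([3, 5]) torch.Size([3, 5, 5])
```

After unpacking, the padding positions of out are filled with zeros (the default padding_value of pad_packed_sequence).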

However, for the loss function F.cross_entropy() to ignore the padding, I should also use F.cross_entropy(output_padded, padded_targets, ignore_index=0).

My question is: what happens if I do all of this, but without using pack_padded_sequence() and pad_packed_sequence()? If F.cross_entropy(output_padded, padded_targets, ignore_index=0) already ignores all 0’s, and assuming that 0’s only occur at the end of sequences, I don’t see how the outcome would differ.

They are two different concepts.

Sequences are packed so that the RNN skips the padding, which in return improves efficiency and can even improve accuracy (intuitively, padding adds noise).
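One way to see what "skipping the padding" means is to inspect the batch_sizes field of a PackedSequence: at each time step, the RNN only processes the sequences that are still active. A minimal sketch with hypothetical lengths 7, 5 and 3 and a dummy feature tensor:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

lengths = torch.tensor([7, 5, 3])   # hypothetical lengths, sorted in decreasing order
padded = torch.zeros(3, 7, 4)       # (batch, max_len, features); contents irrelevant here
packed = pack_padded_sequence(padded, lengths, batch_first=True)

# Number of sequences the RNN processes at each time step.
print(packed.batch_sizes.tolist())  # [3, 3, 3, 2, 2, 1, 1]
# Only 15 of the 3 * 7 = 21 padded positions are actually fed to the RNN.
print(packed.data.shape)            # torch.Size([15, 4])
```

The RNN still steps through time sequentially, so any saving comes from smaller per-step batches rather than from fewer time steps.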

The loss function, on the other hand, is related to the final predictions. At that point you no longer care about padding; you needed it only to get through the RNN step. ignore_index=0 in CrossEntropyLoss means that if the true label at a position is 0, the loss at that position is ignored (it does not contribute to the accumulated loss).

In the example below, the first instance (logits[0]) does not contribute to the final loss because its true label is 0 and we ignore such instances:

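import torch
logits = torch.randn(2, 3)  # 2 instances, 3 classes
loss = torch.nn.CrossEntropyLoss(ignore_index=0)(logits, torch.tensor([0, 2]))  # only logits[1] (true label 2) contributes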
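With the default reduction="mean", the average is taken only over the non-ignored positions, so ignored positions neither add to the loss nor dilute it. A small sketch with hypothetical random logits, checking ignore_index=0 against a hand-made mask:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)           # 4 positions, 3 classes
labels = torch.tensor([0, 2, 1, 0])  # positions 0 and 3 carry the ignored label 0

loss_ignore = F.cross_entropy(logits, labels, ignore_index=0)

# Same thing by hand: per-position losses, keep labels != 0, average over those only.
per_pos = F.cross_entropy(logits, labels, reduction="none")
mask = labels != 0
loss_manual = per_pos[mask].mean()

print(torch.allclose(loss_ignore, loss_manual))  # True
```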

But if all my zeros occur exactly at the end of each sentence, wouldn’t that imply that anything happening at the padding positions (i.e. anything after the end of the sentence) is ignored by the loss function?

In the particular setting I’m describing, I don’t see any difference between the two. One asks to ignore positions whose symbol is 0, the other asks to ignore positions after the end of the sentence (which are exactly the 0’s).

I should also add that I make one prediction for each element of the sequence (since we are trying to predict the next element at every position in the sequence).

The loss computed at some position, let’s call it t, is ignored only if the true label at position t equals the specified ignore_index. So if the true labels equal to ignore_index line up exactly with the padding positions, you can ignore those predictions (as long as you match them up correctly). There is no magic link between these two concepts: nn.CrossEntropyLoss and its ignore_index are unrelated to padding.
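A sketch of "matching them correctly": build the mask of real positions from the sequence lengths and write the ignore value into the padded target positions. It assumes -100 as the padding target (the default ignore_index of F.cross_entropy), so that index 0 can stay a real class; the shapes and random data are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 3, 5, 6                     # hypothetical batch size, max length, vocab size
lengths = torch.tensor([5, 3, 2])
logits = torch.randn(B, T, V)

# Hypothetical targets; class 0 is a *real* token here, so it cannot double as padding.
targets = torch.randint(0, V, (B, T))
# Boolean mask of real (non-padding) positions, built from the lengths: shape (B, T).
valid = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)
# Mark padding positions with -100, the default ignore_index.
targets = targets.masked_fill(~valid, -100)

loss_ignore = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1))

# Equivalent length-based masking of the per-position losses.
per_pos = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1), reduction="none")
loss_mask = per_pos[valid.reshape(-1)].mean()

print(torch.allclose(loss_ignore, loss_mask))  # True
```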

It’s possible that I misunderstood you (there are many different kinds of language modeling). What are your inputs and outputs?

OK, I’ve written a self-contained example to illustrate my point (I couldn’t make it any smaller than this):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

seqs = ['long_str', 'medium', 'tiny']  # already ordered by length
vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq]))
inputs = [s[:len(s) - 1] for s in seqs]  # no prediction for final char
targets = [s[1:len(s)] for s in seqs]  # predict next char
vectorized_inps = [[vocab.index(tok) for tok in seq] for seq in inputs]
vectorized_tgts = [[vocab.index(tok) for tok in seq] for seq in targets]

class RnnLangModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.embed = Embedding(len(vocab), 4)  # embedding_dim = 4
        self.lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)  # input_dim = 4, hidden_dim = 5
        self.decoder = nn.Linear(5, len(vocab))  # output 1 logit for each element in vocab

    def forward(self, inp, seq_lengths):
        embedded_padded_input = self.embed(inp)
        # lengths must be in decreasing order (enforce_sorted=True by default)
        packed_input = pack_padded_sequence(embedded_padded_input, seq_lengths, batch_first=True)
        packed_output, (ht, ct) = self.lstm(packed_input)
        output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
        return self.decoder(output)

    def forward_no_packing(self, inp):
        # the LSTM also runs over the padding positions here
        embedded_padded_input = self.embed(inp)
        output, (ht, ct) = self.lstm(embedded_padded_input)
        return self.decoder(output)

rnn_lm = RnnLangModel()

# prepare input
seq_lengths = torch.LongTensor(list(map(len, vectorized_inps)))  # get the length of each seq
padded_input = torch.zeros((len(vectorized_inps), seq_lengths.max())).long()
for idx, (seq, seqlen) in enumerate(zip(vectorized_inps, seq_lengths)):
    padded_input[idx, :seqlen] = torch.LongTensor(seq)
# prepare output
padded_targets = torch.zeros((len(vectorized_tgts), seq_lengths.max())).long()
for idx, (seq, seqlen) in enumerate(zip(vectorized_tgts, seq_lengths)):
    padded_targets[idx, :seqlen] = torch.LongTensor(seq)

output = rnn_lm(padded_input, seq_lengths)
output_no_packing = rnn_lm.forward_no_packing(padded_input)

ce_loss1 = F.cross_entropy(output.view(-1, output.shape[-1]),  # flatten output, except for logits dim
                           padded_targets.view(-1), ignore_index=0)

ce_loss2 = F.cross_entropy(output_no_packing.view(-1, output_no_packing.shape[-1]),
                           padded_targets.view(-1), ignore_index=0)

ce_loss3_wrong = F.cross_entropy(output.view(-1, output.shape[-1]),
                                 padded_targets.view(-1))  # no ignore_index: padding is scored as class 0

print(ce_loss1.item(), ce_loss2.item(), ce_loss3_wrong.item())  # ce_loss1 and ce_loss2 should match
# ce_loss2.backward()
# ce_loss3_wrong.backward()

Whenever the goal is to make one discrete prediction per sequence element based only on the previous context, we have a setup like the one here.

In these cases, the targets need to be padded just like the input sequences.
Therefore, when it’s time to compute the cross-entropy, we have to use ignore_index; otherwise we end up with ce_loss3_wrong, even if we use pack_padded_sequence.

Since we have to use ignore_index anyway, I was wondering what the need for pack_padded_sequence is at all. Apparently there is none: ce_loss1 and ce_loss2 give the same loss and the same gradient updates to the model (if you inspect output, there are some differences at the padding positions, but they are ignored by F.cross_entropy). There might still be some computation saved by using pack_padded_sequence, but I’m not even sure of that, since the whole batch is processed in parallel at each time step.
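The claim that the packed and unpacked paths give the same loss and gradients can be checked numerically. A minimal sketch with hypothetical sizes and random data, comparing the loss and the LSTM's input-to-hidden weight gradient (lstm.weight_ih_l0) with and without packing. It relies on the LSTM being unidirectional and on padding occurring only at the end, so outputs at real positions never depend on a padding step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
B, T, V, H = 3, 7, 10, 5                 # hypothetical batch, max length, vocab, hidden sizes
lengths = torch.tensor([7, 5, 3])        # sorted in decreasing order
valid = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)

inputs = torch.randint(1, V, (B, T)).masked_fill(~valid, 0)   # 0 = padding
targets = torch.randint(1, V, (B, T)).masked_fill(~valid, 0)

embed = nn.Embedding(V, 4)
lstm = nn.LSTM(4, H, batch_first=True)   # unidirectional, single layer
decoder = nn.Linear(H, V)

def loss_and_grad(use_packing):
    """Return the masked loss and a copy of the LSTM weight_ih_l0 gradient."""
    lstm.zero_grad()
    x = embed(inputs)
    if use_packing:
        packed_out, _ = lstm(pack_padded_sequence(x, lengths, batch_first=True))
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=T)
    else:
        out, _ = lstm(x)                 # also runs over the padding steps
    logits = decoder(out)
    loss = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1), ignore_index=0)
    loss.backward()
    return loss.detach(), lstm.weight_ih_l0.grad.clone()

loss_packed, grad_packed = loss_and_grad(True)
loss_plain, grad_plain = loss_and_grad(False)
print(torch.allclose(loss_packed, loss_plain, atol=1e-6),
      torch.allclose(grad_packed, grad_plain, atol=1e-6))
```

The comparison uses a small tolerance because the two code paths may sum floating-point terms in a different order.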

Now, I’m obviously not claiming that pack_padded_sequence is useless; it can be very useful, for instance, if we only need one final hidden state per sequence. But in this setting (which is not so rare), it seems a bit redundant.
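The final-hidden-state case is where the two paths do differ. A minimal sketch with hypothetical sizes and random inputs: with packing, h_n holds each sequence's state at its last real step; without packing, shorter sequences keep running over the padding, so their last real hidden state has to be gathered from the outputs by hand (valid here because, for a single-layer unidirectional LSTM, output[:, t] is the hidden state at step t).

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
B, T, D, H = 3, 7, 4, 5                  # hypothetical batch, max length, input, hidden sizes
lengths = torch.tensor([7, 5, 3])        # sorted in decreasing order
x = torch.randn(B, T, D)                 # padding positions hold arbitrary values here

lstm = nn.LSTM(D, H, batch_first=True)

# Packed: h_n is the hidden state at each sequence's last *real* step.
_, (h_packed, _) = lstm(pack_padded_sequence(x, lengths, batch_first=True))

# Unpacked: h_n is the hidden state after the last *padded* step.
out_plain, (h_plain, _) = lstm(x)

# Without packing, pick output[b, lengths[b] - 1] for every sequence b.
idx = (lengths - 1).view(B, 1, 1).expand(B, 1, H)
h_gathered = out_plain.gather(1, idx).squeeze(1)   # (B, H), matches h_packed[0]
```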

By the way, the code was loosely inspired by this example.