OK, I’ve written a self-contained example to illustrate my point (I couldn’t make it smaller than this):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
torch.manual_seed(0)
seqs = ['long_str', 'medium', 'tiny'] # already ordered by length
vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq]))
inputs = [s[:len(s) - 1] for s in seqs] # no prediction for final char
targets = [s[1:len(s)] for s in seqs] # predict next char
vectorized_inps = [[vocab.index(tok) for tok in seq] for seq in inputs]
vectorized_tgts = [[vocab.index(tok) for tok in seq] for seq in targets]
class RnnLangModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = Embedding(len(vocab), 4)  # embedding_dim = 4
        self.lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)  # input_dim = 4, hidden_dim = 5
        self.decoder = nn.Linear(5, len(vocab))  # output 1 logit for each element in vocab

    def forward(self, inp):
        # packed version: relies on the global seq_lengths defined below
        embedded_padded_input = self.embed(inp)
        packed_input = pack_padded_sequence(embedded_padded_input, seq_lengths, batch_first=True)
        packed_output, (ht, ct) = self.lstm(packed_input)
        output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
        return self.decoder(output)

    def forward_no_packing(self, inp):
        # padded version: the LSTM also runs over the pad positions
        embedded_padded_input = self.embed(inp)
        output, (ht, ct) = self.lstm(embedded_padded_input)
        return self.decoder(output)
rnn_lm = RnnLangModel()
# prepare input
seq_lengths = torch.LongTensor(list(map(len, vectorized_inps))) # get the length of each seq
padded_input = torch.zeros((len(vectorized_inps), seq_lengths.max())).long()
for idx, (seq, seqlen) in enumerate(zip(vectorized_inps, seq_lengths)):
    padded_input[idx, :seqlen] = torch.LongTensor(seq)
# prepare output
padded_targets = torch.zeros((len(vectorized_tgts), seq_lengths.max())).long()
for idx, (seq, seqlen) in enumerate(zip(vectorized_tgts, seq_lengths)):
    padded_targets[idx, :seqlen] = torch.LongTensor(seq)
output = rnn_lm(padded_input)
output_no_packing = rnn_lm.forward_no_packing(padded_input)
# correct, with packing: pad positions are excluded via ignore_index=0 (the <pad> index)
ce_loss1 = F.cross_entropy(output.view(-1, output.shape[-1]),  # flatten output, except for logits dim
                           padded_targets.view(-1),
                           ignore_index=0)
# correct, without packing: same ignore_index, same result
ce_loss2 = F.cross_entropy(output_no_packing.view(-1, output_no_packing.shape[-1]),
                           padded_targets.view(-1),
                           ignore_index=0)
# wrong: the default ignore_index (-100) never matches, so pad positions are averaged into the loss
ce_loss3_wrong = F.cross_entropy(output.view(-1, output.shape[-1]),
                                 padded_targets.view(-1),
                                 ignore_index=-100)
ce_loss1.backward()
# ce_loss2.backward()
# ce_loss3_wrong.backward()
Whenever the goal is to make one discrete prediction per sequence element based only on the previous context, we have a setup like the one above. In these cases the input sequences need to be padded, and so do the targets. Therefore, when it’s time to compute the cross-entropy, we have to pass ignore_index, otherwise we end up with ce_loss3_wrong, which averages the loss over the padded positions as well, and this is true even if we used pack_padded_sequence.
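To make that concrete, here is a small check that can be appended to the script above (just a sketch reusing the variables from the example; mask, per_tok and manual_loss are names I made up). It recomputes the loss by hand with a pad mask and compares it to the three losses:

mask = padded_targets.view(-1) != 0          # True at real tokens, False at pad positions
per_tok = F.cross_entropy(output.view(-1, output.shape[-1]),
                          padded_targets.view(-1),
                          reduction='none')  # per-position losses, nothing ignored yet
manual_loss = per_tok[mask].mean()           # average over real tokens only
print(ce_loss1.item(), ce_loss2.item(), manual_loss.item())  # these three should match
print(ce_loss3_wrong.item())                 # different: the pad positions are averaged in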
Since we have to use ignore_index anyway, I was wondering what the point of using pack_padded_sequence is here. Apparently there is none, since ce_loss1 and ce_loss2 give the same loss and the same gradient updates to the model (if you inspect output and output_no_packing there are some differences, but they sit exactly at the padded positions that F.cross_entropy ignores). There might still be some computation saved by using pack_padded_sequence, but I’m not even sure of that, since everything is done in parallel anyway.
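The gradient part of that claim can be checked directly (again a sketch on top of the example; grads_packed is my name). Since ce_loss1.backward() has already run, I snapshot those gradients, zero them, and then backprop ce_loss2 through its own graph:

grads_packed = {name: p.grad.clone() for name, p in rnn_lm.named_parameters()}
rnn_lm.zero_grad()                # clear the gradients from the packed version
ce_loss2.backward()               # gradients from the non-packed version
for name, p in rnn_lm.named_parameters():
    print(name, torch.allclose(grads_packed[name], p.grad))  # should print True everywhere (up to fp noise)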
Now, I’m obviously not claiming that pack_padded_sequence is useless; it can be great, for instance, when we only need one final hidden state per sequence (see the sketch at the end of this post). But in this setting (which is not so rare) it seems a bit redundant to use it.
Btw the code was loosely inspired by this example.
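For completeness, here is what that final-hidden-state case looks like (a sketch reusing rnn_lm, padded_input and seq_lengths from the example; the variable names are mine). With packing, ht already sits at each sequence’s last real timestep, while without packing the returned ht corresponds to the padded last column, so the last real state has to be gathered by length:

# packed: ht holds the hidden state at each sequence's last *real* timestep
packed = pack_padded_sequence(rnn_lm.embed(padded_input), seq_lengths, batch_first=True)
_, (ht_packed, _) = rnn_lm.lstm(packed)
last_hidden_packed = ht_packed[-1]                       # (batch, hidden_dim)

# not packed: gather the last real timestep from the full output by length
out_padded, _ = rnn_lm.lstm(rnn_lm.embed(padded_input))  # (batch, max_len, hidden_dim)
idx = (seq_lengths - 1).view(-1, 1, 1).expand(-1, 1, out_padded.size(-1))
last_hidden_padded = out_padded.gather(1, idx).squeeze(1)

print(torch.allclose(last_hidden_packed, last_hidden_padded))  # True (up to fp noise)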