Passing minibatches of sequential data through a bidirectional RNN

Hi, can someone please help me by explaining how to correctly pass minibatches of sequential data through a bidirectional RNN? And perhaps show an example, if possible?

I will try to provide some context for my problem:

The problem is similar to a language modeling task. I’m trying to predict the next item for each item in a sequence.
In my case the input data to the model is a minibatch of N sentences of varying length. Each sentence consists of word indices, each representing a word in the vocabulary:

sents = [[4, 545, 23, 1], [34, 84], [23, 6, 774]]

The sentences in the dataset are randomly shuffled before creating minibatches.
Here is how the minibatches are created:

import random
import torch

def batches(data, batch_size):
    """ Yields batches of sentences from 'data', each batch sorted by decreasing length. """
    random.shuffle(data)
    for i in range(0, len(data), batch_size):
        sentences = data[i:i + batch_size]
        # pack_sequence expects sequences sorted by decreasing length by default
        sentences.sort(key=len, reverse=True)
        yield [torch.LongTensor(s) for s in sentences]

The model predicts the next element in the sentence, so the input and target look like this:

input_sentence = [1, 4, 5, 7]
target_sentence = [4, 5, 7, 9]

Packed sequences are used to handle sentences of varying length. Here is how the input and target are created:

x = nn.utils.rnn.pack_sequence([s[:-1] for s in sents])
y = nn.utils.rnn.pack_sequence([s[1:] for s in sents])
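A small sketch of what the two packed sequences contain, using the example sentences from above reordered by decreasing length (required by pack_sequence's default enforce_sorted=True, which is what batches does). The packed data is stored time-major, and because each sentence's input and target have the same length, x.data and y.data line up element by element: the target for x.data[i] is y.data[i].

```python
import torch
import torch.nn as nn

# The example minibatch, sorted by decreasing length as batches() would yield it.
sents = [torch.LongTensor(s) for s in [[4, 545, 23, 1], [23, 6, 774], [34, 84]]]

x = nn.utils.rnn.pack_sequence([s[:-1] for s in sents])  # inputs
y = nn.utils.rnn.pack_sequence([s[1:] for s in sents])   # next-item targets

# Time-major layout: step 0 of every sentence, then step 1, and so on.
print(x.data.tolist())         # [4, 23, 34, 545, 6, 23]
print(y.data.tolist())         # [545, 6, 84, 23, 774, 1]
print(x.batch_sizes.tolist())  # [3, 2, 1]
```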

This input x, consisting of a minibatch of sentences, is then sent through the forward pass of the model:

out = model(x)

The model itself:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """ A language model RNN with GRU layer(s). """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, gru_layers, dropout):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.recurrent_layer = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim,  num_layers=gru_layers, dropout=dropout, bidirectional=True)
        self.fc1 = nn.Linear(hidden_dim*2, vocab_size)

    def forward(self, packed_sents):
        """ Takes a PackedSequence of sentences tokens that has T tokens
        belonging to vocabulary V. Outputs predicted log-probabilities
        for the token following the one that's input in a tensor shaped
        (T, |V|).
        """
        embedded_sents = nn.utils.rnn.PackedSequence(self.embedding(packed_sents.data), packed_sents.batch_sizes)
        out_packed_sequence, hidden = self.recurrent_layer(embedded_sents)
        out = self.fc1(out_packed_sequence.data)
        return F.log_softmax(out, dim=1)

For every item in the sequence, the output of the model is a log-probability distribution over the unique items in the dataset, representing how likely each of them is to be the next item in the sequence.
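For reference, a minimal sketch of one training step with this setup, assuming made-up sizes and an Adam optimizer (neither is taken from the linked training code). It reproduces Model.forward inline so the block is self-contained. Because the model returns one row per packed token, in the same time-major order as y.data, the negative log-likelihood loss can be applied directly against y.data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, for illustration only.
vocab_size, embedding_dim, hidden_dim = 1000, 16, 32

embedding = nn.Embedding(vocab_size, embedding_dim)
gru = nn.GRU(embedding_dim, hidden_dim, bidirectional=True)
fc1 = nn.Linear(hidden_dim * 2, vocab_size)
params = list(embedding.parameters()) + list(gru.parameters()) + list(fc1.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

sents = [torch.LongTensor(s) for s in [[4, 545, 23, 1], [23, 6, 774], [34, 84]]]
x = nn.utils.rnn.pack_sequence([s[:-1] for s in sents])
y = nn.utils.rnn.pack_sequence([s[1:] for s in sents])

# Same computation as Model.forward: embed the packed data, run the GRU, project.
embedded = nn.utils.rnn.PackedSequence(embedding(x.data), x.batch_sizes)
out_packed, hidden = gru(embedded)
log_probs = F.log_softmax(fc1(out_packed.data), dim=1)  # (T, vocab_size), T = 6 here

# One row per packed token, aligned with y.data.
loss = F.nll_loss(log_probs, y.data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```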

Around 100 000 sequences are used for training the model, and 100 000 are used for testing the model.

The model is evaluated using top-k accuracy/precision@K:

def topk_accuracy(output_distribution, targets, k):
    _, pred = torch.topk(input=output_distribution, k=k, dim=1)
    pred = pred.t()
    correct = pred.eq(targets.expand_as(pred))
    return correct.sum().item() / targets.shape[0]
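A small usage sketch of topk_accuracy with made-up scores, copied from above with shape comments added. It expects output_distribution shaped (T, |V|) and a 1-D targets tensor of length T (for example y.data). Since the k indices returned by torch.topk are distinct, each position matches at most once, so the result lies between 0 and 1.

```python
import torch

def topk_accuracy(output_distribution, targets, k):
    # output_distribution: (T, |V|) scores; targets: (T,) true next-item ids
    _, pred = torch.topk(input=output_distribution, k=k, dim=1)  # (T, k)
    pred = pred.t()                                              # (k, T)
    correct = pred.eq(targets.expand_as(pred))                   # (k, T) booleans
    return correct.sum().item() / targets.shape[0]

# 3 positions, vocabulary of 4 items (made-up scores).
scores = torch.tensor([[0.1, 0.7, 0.1, 0.1],
                       [0.6, 0.1, 0.2, 0.1],
                       [0.1, 0.2, 0.3, 0.4]])
targets = torch.tensor([1, 2, 0])

print(topk_accuracy(scores, targets, k=1))  # 1 of 3 positions correct
print(topk_accuracy(scores, targets, k=2))  # 2 of 3 positions correct
```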

I have evaluated the model with different variants of the recurrent layer: GRU, LSTM, bidirectional GRU and bidirectional LSTM.

The code for training the model:
https://github.com/aksellk/recsys_telenor/blob/master/prediction/main.py

The model itself:
https://github.com/aksellk/recsys_telenor/blob/master/prediction/model.py

The code for evaluating the model:
https://github.com/aksellk/recsys_telenor/blob/master/prediction/testing.py

When bidirectional GRU/LSTM is used, the top-1 accuracy is around 94%.
When plain GRU/LSTM is used, the top-1 accuracy is around 37%.
I suspect something in my experiments is wrong, because the bidirectional model achieves far better results than the plain versions of GRU/LSTM.

I have also run an experiment where I calculate the average precision@1 for the items at each position in the sessions: for example, the average precision@1 for the first item over all sessions, then for the second position, the third, and so on. The idea is to investigate how the model's predictions change over the course of a session.

The code for this experiment is in a jupyter notebook:
https://github.com/aksellk/recsys_telenor/blob/master/prediction/average_precision_for_the_nth_interaction.ipynb
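The notebook is not reproduced here; as a rough sketch of the same idea, the position of every packed token can be recovered from batch_sizes, because packed data is time-major. top1_accuracy_by_position is a hypothetical helper name. Position t here is the input position, so position 0 is the prediction of the second item from the first.

```python
import torch
import torch.nn.functional as F

def top1_accuracy_by_position(log_probs, targets, batch_sizes):
    """Hypothetical helper: average top-1 accuracy for each position t.

    log_probs:   (T, |V|) model output for the packed tokens
    targets:     (T,) packed targets, i.e. y.data
    batch_sizes: PackedSequence.batch_sizes (sequences still active per step)
    """
    # The first batch_sizes[0] rows are position 0, the next batch_sizes[1]
    # rows are position 1, and so on.
    positions = torch.repeat_interleave(torch.arange(len(batch_sizes)), batch_sizes)
    correct = (log_probs.argmax(dim=1) == targets).float()
    # Sum the correct predictions per position, then divide by the count.
    totals = torch.zeros(len(batch_sizes)).index_add_(0, positions, correct)
    return totals / batch_sizes.float()

# Made-up example: 3 sequences with input lengths 3, 2 and 1.
batch_sizes = torch.tensor([3, 2, 1])
targets = torch.tensor([1, 0, 2, 1, 1, 0])
preds = torch.tensor([1, 0, 0, 1, 2, 0])
log_probs = F.one_hot(preds, num_classes=3).float()  # argmax equals preds

accuracies = top1_accuracy_by_position(log_probs, targets, batch_sizes)
print(accuracies)  # per-position accuracies 2/3, 1/2 and 1
```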

Here is a plot from this experiment showing the top1-accuracy of the bidirectional LSTM model and the plain LSTM model. The x-axis represents the position of the items in the sequences:

I suspect there is something wrong with the experiment, because the bidirectional model shows a top-1 accuracy above 90% for every position in the sequence, which seems too good to be true.
Further, the top-1 accuracy for the first items in each sequence should be affected by the cold-start problem. For the plain LSTM, the average top-1 accuracy for the first item is much lower than for the rest of the items. This is not the case for the bidirectional model, where it looks too high.

So to sum up, I believe there is something wrong in how I pass sequences in minibatches through the bidirectional model, and I would like help understanding how to do this properly. I would also appreciate any ideas as to why the bidirectional variant of the model shows such a high top-k accuracy.

This question is related to a previously asked question on the forum:
https://discuss.pytorch.org/t/handling-the-hidden-state-with-minibatches-in-a-rnn-for-language-modelling/44261/4

If something is unclear, please feel free to ask me.
Any help is very much appreciated.
Thanks!

Here’s what I think(!) is going on:

You want to train a network that takes a sequence of words as input to predict the next word (many-to-one). But what you’re actually training is a many-to-many network, more specifically a sequence tagging network like for POS or NER tagging.

Given your example, your data item should look more like:

input_sentence = [1, 4, 5, 7]
target_word = 9

Why should the network learn that 1 maps to 4, 4 maps to 5, and so on? That's not the task, but it affects your loss and hence what your network learns.

In my opinion, that would also explain why the Bi-LSTM is so much better than the LSTM. The Bi-LSTM also reads the sentence from the end, so at every position except the last one, its backward direction has already seen the next input, which is exactly the target for that position (the target sequence is just the input shifted by one). It can simply learn to copy it. The simple LSTM would have to look into the future for that.
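One way to check this explanation, sketched with an untrained nn.GRU and made-up sizes: take the gradient of the output at position 0 with respect to the input at position 1. For a unidirectional GRU it is exactly zero, since the output at step 0 cannot depend on later inputs. For a bidirectional GRU it is not, because the backward direction has already read step 1, which in the shifted setup above is the target for step 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def depends_on_next_input(bidirectional):
    """Does the output at time step 0 depend on the input at time step 1?"""
    gru = nn.GRU(input_size=4, hidden_size=8, bidirectional=bidirectional)
    inputs = torch.randn(5, 1, 4, requires_grad=True)  # (seq_len, batch, features)
    out, _ = gru(inputs)                                # (seq_len, batch, dirs * 8)
    out[0].sum().backward()                             # output at position 0 only
    # Gradient with respect to the input at position 1.
    return inputs.grad[1].abs().sum().item() > 0

print(depends_on_next_input(bidirectional=False))  # False
print(depends_on_next_input(bidirectional=True))   # True
```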

So this is what I would try:

  • Change your dataset such that one data item is a sequence as input and a single word as output, and treat it like a classification (many-to-one) task
  • Take hidden and not out_packed_sequence as input for the fc1 layer! You only want the last state for many-to-one, and in a Bi-LSTM the last outputs of the two directions are at “opposite ends” of out_packed_sequence (see also my post), so it's much simpler to use hidden. Here's an example of an RNN-based classifier that might help. In your case, label_size will also be vocab_size.
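A minimal sketch of the many-to-one idea in the second bullet, assuming a bidirectional GRU as in the original model. NextItemClassifier is a hypothetical name, not the linked classifier. For nn.GRU, hidden has shape (num_layers * 2, batch, hidden_dim), and its last two entries are the top layer's forward state (after the last real token of each packed sequence) and backward state (after reading back to the first token). Since the input no longer contains the target, the backward direction has nothing to copy.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NextItemClassifier(nn.Module):
    """Hypothetical many-to-one variant: one prediction per input sequence."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, gru_layers=1, dropout=0.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.recurrent_layer = nn.GRU(embedding_dim, hidden_dim, num_layers=gru_layers,
                                      dropout=dropout, bidirectional=True)
        self.fc1 = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, packed_sents):
        embedded = nn.utils.rnn.PackedSequence(self.embedding(packed_sents.data),
                                               packed_sents.batch_sizes)
        _, hidden = self.recurrent_layer(embedded)
        # Top layer: hidden[-2] is the forward state, hidden[-1] the backward one.
        last = torch.cat([hidden[-2], hidden[-1]], dim=1)  # (batch, 2 * hidden_dim)
        return F.log_softmax(self.fc1(last), dim=1)        # (batch, vocab_size)

# Usage with the example from this thread, sorted by decreasing length.
model = NextItemClassifier(vocab_size=1000, embedding_dim=16, hidden_dim=32)
inputs = [torch.LongTensor(s) for s in [[1, 4, 5, 7], [1, 4, 5], [1, 4]]]
targets = torch.LongTensor([9, 7, 5])
log_probs = model(nn.utils.rnn.pack_sequence(inputs))  # one row per sequence
loss = F.nll_loss(log_probs, targets)
```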

I hope that helps.


Thank you, @vdw! I believe you are right. The data item should be represented like you said:

input_sentence = [1, 4, 5, 7]
target_word = 9

But I also think the prediction should be repeated for each item in the sequence. If we only consider the last element of a sequence, that is the element with the longest history of previous elements. So to make a fair evaluation, the prediction should be repeated for each of the following items in the sequence, splitting the sequences like this:

input_sentence = [1]
target_word = 4

input_sentence = [1, 4]
target_word = 5

input_sentence = [1, 4, 5]
target_word = 7

input_sentence = [1, 4, 5, 7]
target_word = 9

But I think splitting sequences like this can be hard to combine with minibatching during training, especially since the sequences have varying lengths.

My idea is therefore to carry out this sequence split before training and testing the model, that is, before loading the dataset. This would increase the size of the dataset, but I believe it would make it easier to combine session splitting with minibatching. Do you think that would make sense?

This is how I would create the dataset:

import random
import torch

def next_item_sequence_split(data):
    """ Splits every sequence into all of its prefixes of length >= 2,
    which adds more training data like this:

    [1, 2]
    [1, 2, 3]
    [1, 2, 3, 4]

    The last element of each prefix is the target; the rest is the input.
    """
    new_data = []
    for d in data:
        splits = [torch.LongTensor(d[:i]) for i in range(2, len(d) + 1)]
        new_data.extend(splits)
    random.shuffle(new_data)
    return new_data

def split_dataset(dataset):
    """ Splits the data into 80% training and 20% testing. """
    random.shuffle(dataset)
    train_size = int(0.8 * len(dataset))
    train = dataset[:train_size]
    test = dataset[train_size:]  # [train_size:-1] would silently drop the last sequence
    return train, test

# Split the whole sequences into training and testing first (x is a list of sequences),
# so that prefixes of the same sequence never end up in both sets
train, test = split_dataset(x)

x_train = next_item_sequence_split(train)
x_test = next_item_sequence_split(test)
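After this split, every data item is a prefix whose last element is the target, so minibatching with varying lengths can work as before: sort each batch by length, pack everything except the last item as the input, and stack the last items into a target tensor. A minimal sketch; prefix_batches is a hypothetical helper, and the resulting targets suit a many-to-one model that returns one row per sequence.

```python
import random
import torch
import torch.nn as nn

def prefix_batches(prefixes, batch_size):
    """Yield (packed_inputs, targets) pairs from prefixes such as [1, 4, 5, 7, 9]."""
    random.shuffle(prefixes)
    for i in range(0, len(prefixes), batch_size):
        batch = sorted(prefixes[i:i + batch_size], key=len, reverse=True)
        inputs = nn.utils.rnn.pack_sequence([p[:-1] for p in batch])  # all but the last item
        targets = torch.stack([p[-1] for p in batch])                  # the item to predict
        yield inputs, targets

prefixes = [torch.LongTensor(p) for p in [[1, 4], [1, 4, 5], [1, 4, 5, 7], [1, 4, 5, 7, 9]]]
for inputs, targets in prefix_batches(prefixes, batch_size=2):
    print(inputs.batch_sizes.tolist(), targets.tolist())
```

Because the prefixes are shuffled, the printed batches differ from run to run.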

Thanks!

Whether the splitting makes sense depends on your data and task.

You can have a look at this thread to generate batches of sequences of equal length. I use this all the time since it’s worry-free: neither padding nor packing is needed.
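The linked thread is not reproduced here; as a rough sketch of the general idea (not necessarily that thread's code), sequences can be grouped by length so that every batch stacks into a plain (batch, seq_len) tensor with no padding or packing. equal_length_batches is a hypothetical name, and the last batch for each length may be smaller than batch_size.

```python
import random
from collections import defaultdict

import torch

def equal_length_batches(sequences, batch_size):
    """Yield LongTensors of shape (batch, seq_len) where all rows share one length."""
    buckets = defaultdict(list)
    for seq in sequences:
        buckets[len(seq)].append(seq)
    batches = []
    for same_length in buckets.values():
        random.shuffle(same_length)
        for i in range(0, len(same_length), batch_size):
            batches.append(torch.LongTensor(same_length[i:i + batch_size]))
    random.shuffle(batches)  # mix lengths across the epoch
    yield from batches

data = [[4, 545, 23, 1], [34, 84], [23, 6, 774], [7, 8], [1, 2, 3]]
for batch in equal_length_batches(data, batch_size=2):
    print(batch.shape)
```

With nn.GRU's default batch_first=False, each batch would need to be transposed (batch.t()), or the GRU created with batch_first=True.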

Thank you! I will have a look at it.