How to fill in the blank using a bidirectional RNN and PyTorch?



I am trying to fill in the blank in a sentence using a bidirectional RNN and PyTorch.

The input will be like: The dog is _____, but we are happy he is okay.

The output will be like:

1. hyper (Perplexity score here) 
2. sad (Perplexity score here) 
3. scared (Perplexity score here)

I discovered this idea here: https://medium.com/@plusepsilon/the-bidirectional-language-model-1f3961d1fb27

import torch
import torch.nn as nn

text = ['BOS', 'How', 'are', 'you', 'EOS']
seq_len = len(text)
batch_size = 1
embedding_size = 1
hidden_size = 1
output_size = 1

# random stand-in for word embeddings, shape (seq_len, batch_size, embedding_size)
random_input = torch.randn(seq_len, batch_size, embedding_size)

bi_rnn = torch.nn.RNN(
    input_size=embedding_size, hidden_size=hidden_size, num_layers=1, batch_first=False, bidirectional=True)

bi_output, bi_hidden = bi_rnn(random_input)

# stagger
forward_output, backward_output = bi_output[:-2, :, :hidden_size], bi_output[2:, :, hidden_size:]
staggered_output = torch.cat((forward_output, backward_output), dim=-1)

linear = nn.Linear(hidden_size * 2, output_size)

# only predict on words
labels = random_input[1:-1]

# for language models, use cross-entropy :)
loss = nn.MSELoss()
output = loss(linear(staggered_output), labels)
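
The comment in the code says to use cross-entropy for language models, so here is my rough attempt at what that might look like with real token ids instead of random input. This is only a sketch under my own assumptions: the toy vocabulary, the names `word_to_id`, `embed`, and `to_vocab`, and the layer sizes are all made up by me and do not come from the blog post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy vocabulary (hypothetical; a real one would be built from the training corpus)
vocab = ['BOS', 'EOS', 'How', 'are', 'you']
word_to_id = {w: i for i, w in enumerate(vocab)}

text = ['BOS', 'How', 'are', 'you', 'EOS']
token_ids = torch.tensor([word_to_id[w] for w in text])  # shape: (seq_len,)

embedding_size, hidden_size = 8, 16  # arbitrary sizes chosen for illustration
embed = nn.Embedding(len(vocab), embedding_size)
bi_rnn = nn.RNN(embedding_size, hidden_size, bidirectional=True)
to_vocab = nn.Linear(hidden_size * 2, len(vocab))

# (seq_len, batch=1, embedding_size), since batch_first defaults to False
inputs = embed(token_ids).unsqueeze(1)
bi_output, _ = bi_rnn(inputs)  # (seq_len, 1, 2 * hidden_size)

# Stagger: the prediction for position t uses the forward state at t-1
# and the backward state at t+1, so the word at t itself is never seen
forward_output = bi_output[:-2, :, :hidden_size]
backward_output = bi_output[2:, :, hidden_size:]
staggered = torch.cat((forward_output, backward_output), dim=-1)

logits = to_vocab(staggered).squeeze(1)  # (seq_len - 2, vocab_size)
targets = token_ids[1:-1]                # the actual words, without BOS/EOS

loss = nn.CrossEntropyLoss()(logits, targets)
loss.backward()
```

If I read the staggering correctly, each of the three real words gets one row of scores over the whole vocabulary, and the loss compares those scores with the id of the word that was actually there.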

I am trying to reimplement the code above, which comes from the bottom of the blog post. I am new to PyTorch and NLP, and I can't understand what the input and output of the code are.

Question about the input: I am guessing the input is the few words that are given. Why does one need beginning-of-sentence and end-of-sentence tags in this case? Why isn't the input a corpus on which the model is trained, as in other classic NLP problems? I would like to use the Enron email corpus to train the RNN.

Question about the output: I see that the output is a tensor. My understanding is that a tensor is a vector, so maybe a word vector in this case. How can I use the tensor to output the words themselves?
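
To make the output question concrete, this is my current guess at how scores could be turned back into words, assuming the model produced one score per vocabulary word for the blank position. The vocabulary and the score values below are invented for illustration, and the sketch prints probabilities rather than the perplexity scores I mentioned above.

```python
import torch

# Hypothetical toy vocabulary and made-up scores for the blank position
vocab = ['hyper', 'sad', 'scared', 'the', 'dog']
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -2.0])  # one score per vocabulary word

# Turn the scores into a probability distribution over the vocabulary
probs = torch.softmax(logits, dim=-1)

# Keep the three most likely word ids and look up the matching words
top_probs, top_ids = torch.topk(probs, k=3)
candidates = [(vocab[i], p) for i, p in zip(top_ids.tolist(), top_probs.tolist())]

for rank, (word, p) in enumerate(candidates, start=1):
    print(f"{rank}. {word} (p = {p:.3f})")
```

Is this index lookup the right idea, or is there a more standard way to go from the tensor to words?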