Difficulty with LSTMs for Text Generation

I am working on character-level quote generation with LSTMs in PyTorch, and I am having trouble understanding exactly how the hidden state is handled in PyTorch.

Some details:

I have a list of quotes from a character in a TV series. I convert each quote to a sequence of integers, with every character mapped to an integer via a dictionary char2idx. I also have the inverse mapping idx2char.
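For reference, the encoding step looks roughly like this (a simplified sketch; quotes stands for my list of raw quote strings, and the exact handling of the <EOS> token may differ slightly in my actual code):

EOS_TOKEN = "<EOS>"

# Build the vocabulary over all characters appearing in the quotes,
# plus the special end-of-sequence token.
chars = sorted(set("".join(quotes))) + [EOS_TOKEN]
char2idx = {ch: i for i, ch in enumerate(chars)}
idx2char = {i: ch for ch, i in char2idx.items()}

def quote2seq(quote):
    # Encode a quote as a list of integer indices, appending the EOS index.
    return [char2idx[ch] for ch in quote] + [char2idx[EOS_TOKEN]]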

After that, I am using a sliding window, say of size window_size, and a step of size step to prepare the data.

As an example, let’s say the sequence is [1, 2, 3, 4, 5, 0], where 0 stands for the EOS character. Then, using window_size = 3 and step = 2, I get the sequences for x and y as:

x1 = [1, 2, 3], y1 = [2, 3, 4]
x2 = [3, 4, 5], y2 = [4, 5, 0]

x = [x1, x2], y = [y1, y2]
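In code, the windowing looks roughly like this (a sketch; seqs and make_windows are placeholder names, with seqs standing for my list of integer-encoded quotes):

def make_windows(seq, window_size, step):
    # Slide a window over one encoded quote; the target is the input
    # shifted one character to the right.
    xs, ys = [], []
    for i in range(0, len(seq) - window_size, step):
        xs.append(seq[i:i + window_size])
        ys.append(seq[i + 1:i + 1 + window_size])
    return xs, ys

x, y = [], []
for seq in seqs:
    xs, ys = make_windows(seq, window_size, step)
    x += xs
    y += ys

With the example sequence above, this produces exactly x = [x1, x2] and y = [y1, y2].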

The next step is to train the model. I have attached the code I am using to train the model.

NOTE: I am not passing hidden states from one batch to the next, because the ith sequence of the (j+1)th batch is generally not the continuation of the ith sequence of the jth batch. (This is why I am using a sliding window, to help the model remember context.) Is there a better way to do this?
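For context, this is roughly what I mean by passing the hidden state between batches, which I am not doing. It refers to the names from my training loop below, BATCH_SIZE is a placeholder for my batch size, and it would only make sense if the (j+1)th batch actually continued the jth batch:

# Hypothetical stateful variant of my training loop: carry (h, c) across batches
# and detach them so gradients do not flow over batch boundaries.
# (Assumes every batch contains exactly BATCH_SIZE sequences.)
h, c = rnn.init_hidden(BATCH_SIZE)
for x, y in train_loader:
    optimizer.zero_grad()
    h, c = h.detach(), c.detach()        # keep the values, cut the autograd graph
    o, h, c = rnn(x, h, c)
    loss = criterion(o, y.reshape(-1))
    loss.backward()
    optimizer.step()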

My main question concerns test (generation) time. There are two methods I am using to generate text.

Method 1:
I take the initial seed string, pass it into the model, and get the next character as the prediction. I then append that character to the seed and pass the whole sequence into the model again, without carrying over the hidden state. That is, at every step I feed the entire sequence generated so far into the model (with the LSTM's hidden state initialised to zeros), take the predicted character, append it to the sequence, and repeat until I encounter the EOS character.

Method 2:
I take the initial seed string, pass it into the model, and get the next character as the prediction. From then on, I pass only the newly generated character together with the previous hidden state as the next input, and keep doing so until an EOS character is encountered.

Questions

  1. According to my current understanding, the outputs of the two methods should be identical, because the same computation should be happening in both.
  2. What actually happens is that the two methods give completely different results. Why is this happening?
  3. The second method gets stuck in an infinite loop for most inputs (e.g. it produces “back to back to back to …”), and for some inputs the first method also gets stuck. How can I prevent this?
  4. Is this related in some way to the training?

I have tried several variations: bidirectional LSTMs, one-hot encoding instead of an embedding layer, different batch sizes, and dropping the sliding-window approach altogether (padding each quote and feeding it in whole).

I cannot figure out how to solve this issue. Any help would be greatly appreciated.

CODE

Code for the Model Class:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class RNN(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, dropout=0.15):
        super(RNN, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.dense1 = nn.Linear(hidden_size, hidden_size*4)
        self.dense2 = nn.Linear(hidden_size*4, hidden_size*2)
        self.dense3 = nn.Linear(hidden_size*2, vocab_size)
        self.drop = nn.Dropout(dropout)

    def forward(self, X, h=None, c=None):
        if h is None:
            h, c = self.init_hidden(X.size(0))  # Default to a zero hidden/cell state
        out = self.embedding(X)
        out, (h, c) = self.lstm(out, (h, c))
        out = self.drop(out)
        out = self.dense1(out.reshape(-1, self.hidden_size))  # Reshape to (batch_size*seq_len, hidden_size)
        out = self.dense2(out)
        out = self.dense3(out)                                # Logits of shape (batch_size*seq_len, vocab_size)
        return out, h, c

    def init_hidden(self, batch_size):
        num_l = self.num_layers
        hidden = torch.zeros(num_l, batch_size, self.hidden_size).to(DEVICE)
        cell = torch.zeros(num_l, batch_size, self.hidden_size).to(DEVICE)
        return hidden, cell

Code for training:

rnn = RNN(VOCAB_SIZE, HIDDEN_SIZE, NUM_LAYERS).to(DEVICE)
optimizer = torch.optim.Adam(rnn.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

rnn.train()
history = {}
best_loss = 100

for epoch in range(EPOCHS):  # EPOCH LOOP
    counter = 0
    epoch_loss = 0

    for x, y in train_loader:  # BATCH LOOP
        optimizer.zero_grad()
        counter += 1

        o, h, c = rnn(x)                     # Hidden state starts at zero for every batch
        loss = criterion(o, y.reshape(-1))   # o: (batch*seq_len, vocab), y: (batch*seq_len,)
        epoch_loss += loss.item()

        loss.backward()
        nn.utils.clip_grad_norm_(rnn.parameters(), 5)  # Clipping gradients
        optimizer.step()

        if counter % print_every == 0:
            print(f"[INFO] EPOCH: {epoch+1}, BATCH: {counter}, TRAINING LOSS: {loss.item()}")

    epoch_loss = epoch_loss / counter
    history["train_loss"] = history.get("train_loss", []) + [epoch_loss]
    print(f"\nEPOCH: {epoch+1} COMPLETED!\nTRAINING LOSS: {epoch_loss}\n")

Method 1 Code:

with torch.no_grad():
    w = None
    start_str = "Hey, "
    x1 = quote2seq(start_str)[:-1]          # Encode the seed, drop the trailing EOS

    while w != EOS_TOKEN:
        x1 = torch.tensor(x1, device=DEVICE).unsqueeze(0)    # Shape (1, seq_len)
        o1, h1, c1 = rnn(x1)                # Hidden state re-initialised to zeros every step
        p1 = F.softmax(o1, dim=1).detach()
        q1 = np.argmax(p1.cpu(), axis=1)[-1].item()          # Greedy pick at the last position
        w = idx2char[q1]
        start_str += w
        x1 = x1.tolist()[0] + [q1]          # Feed the whole sequence plus the new char next time

quote = start_str.replace("<EOS>", "")
quote

Method 2 Code:

with torch.no_grad():
    w = None
    start_str = "Are we back"
    x1 = quote2seq(start_str)[:-1]          # Encode the seed, drop the trailing EOS
    h1, c1 = rnn.init_hidden(1)             # Zero hidden state for batch size 1

    while w != EOS_TOKEN:
        x1 = torch.tensor(x1, device=DEVICE).unsqueeze(0)
        h1, c1 = h1.data, c1.data           # Detach the hidden state from the autograd graph
        o1, h1, c1 = rnn(x1, h1, c1)        # Carry the hidden state between steps
        p1 = F.softmax(o1, dim=1).detach()
        q1 = np.argmax(p1.cpu(), axis=1)[-1].item()          # Greedy pick at the last position
        w = idx2char[q1]
        start_str += w
        x1 = [q1]                           # Feed only the newly generated character next time

quote = start_str.replace("<EOS>", "")
quote