Shouldn't LSTM and LSTMCell produce an identical sequence of hidden states when both are fed one timestep at a time?

Hi,

I’m implementing a Seq2Seq decoder, and the paper says that an LSTMCell should be used in a loop, once per decoder step. However, I don’t see why LSTM itself couldn’t be used. I’d expect the two procedures to yield the same sequence of hidden states when looped over one step at a time. Here’s the code that I thought would give the same results, but it didn’t.

I thought the differences might come from some non-determinism in the LSTM layer, as stated in the documentation, so I also set all possible random seeds, but it looks like that’s not the problem.

import random
import torch
import numpy as np
from torch import nn

random.seed(0)
torch.manual_seed(100)
np.random.seed(0)

torch.use_deterministic_algorithms(True)

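# A single-layer LSTM and an LSTMCell with matching sizes
lstm = nn.LSTM(input_size=5, hidden_size=5, num_layers=1)
lstm_cell = nn.LSTMCell(5, 5)

# Same initial states for both runs, shape (1, 5)
h0 = torch.tensor([0.1, 0.2, 0.12, -0.3, 0.1]).unsqueeze(0)
c0 = torch.tensor([0.1, 0.2, 0.12, -0.3, 0.1]).unsqueeze(0)
x = torch.randn(8, 5)  # 8 timesteps, input_size 5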

with torch.no_grad():
    h, c = h0, c0
    for i in range(8):
        _, (h, c) = lstm(x[i].unsqueeze(0), (h, c))
        print(f"{h=} {c=}")

    print("\n")
    h, c = h0, c0
    for i in range(8):
        h, c = lstm_cell(x[i].unsqueeze(0), (h, c))
        print(f"{h=} {c=}")
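A quick way to test the seeding hypothesis is to compare the two modules' parameters right after construction. A minimal sketch, reusing the sizes and seed from the snippet above (the variable name `same_weights` is just for illustration):

```python
import torch
from torch import nn

torch.manual_seed(100)
lstm = nn.LSTM(input_size=5, hidden_size=5, num_layers=1)
lstm_cell = nn.LSTMCell(5, 5)  # initialized from the continuation of the same RNG stream

# The input-to-hidden weights have matching shapes (4 * hidden_size, input_size)...
print(lstm.weight_ih_l0.shape, lstm_cell.weight_ih.shape)
# → torch.Size([20, 5]) torch.Size([20, 5])

# ...but different values, so the two modules compute different functions.
same_weights = torch.equal(lstm.weight_ih_l0, lstm_cell.weight_ih)
print(same_weights)
# → False
```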

Hello Miro,
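Indeed, a single-layer nn.LSTM computes the same per-step recurrence as a single nn.LSTMCell. However, even though you fix the random seed, the two modules are initialized with different weights: both constructors draw their initial values from the same random number stream, one after the other, so the LSTMCell receives the numbers that come after the LSTM’s rather than the same ones.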
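For reference, this is the per-step recurrence both modules compute, as written in the PyTorch documentation for nn.LSTM and nn.LSTMCell. There, `weight_ih` (`weight_ih_l0` in nn.LSTM) stacks W_ii, W_if, W_ig, W_io, `weight_hh` stacks the corresponding W_h* matrices, the two bias vectors are stacked the same way, σ is the logistic sigmoid and ⊙ is the elementwise product. Given the same weights and the same (x_t, h_{t-1}, c_{t-1}), both modules therefore return the same (h_t, c_t):

```latex
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```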

If you make the weights of one module identical to those of the other, their outputs should become identical as well. Try adding the following to your code, before the loops:

# Point the cell's parameters at the LSTM's layer-0 parameters (both modules then share them)
lstm_cell.weight_hh = lstm.weight_hh_l0
lstm_cell.bias_hh = lstm.bias_hh_l0
lstm_cell.weight_ih = lstm.weight_ih_l0
lstm_cell.bias_ih = lstm.bias_ih_l0
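A self-contained sketch of the whole comparison with the weights aligned. As an assumed variation on the assignment above, it copies the values in place with `Tensor.copy_` under `torch.no_grad()`, so the two modules keep separate parameters instead of sharing them, and it checks the result with `torch.allclose` rather than by eye. The random initial states and the explicit batch dimension of 1 are also assumptions, chosen for brevity:

```python
import torch
from torch import nn

torch.manual_seed(100)
lstm = nn.LSTM(input_size=5, hidden_size=5, num_layers=1)
lstm_cell = nn.LSTMCell(5, 5)

with torch.no_grad():
    # Copy the LSTM's layer-0 parameters into the cell (values only, no sharing).
    lstm_cell.weight_ih.copy_(lstm.weight_ih_l0)
    lstm_cell.weight_hh.copy_(lstm.weight_hh_l0)
    lstm_cell.bias_ih.copy_(lstm.bias_ih_l0)
    lstm_cell.bias_hh.copy_(lstm.bias_hh_l0)

    x = torch.randn(8, 1, 5)   # (seq_len, batch, input_size)
    h0 = torch.randn(1, 1, 5)  # nn.LSTM expects (num_layers, batch, hidden_size)
    c0 = torch.randn(1, 1, 5)

    # nn.LSTM fed one timestep at a time.
    h, c = h0, c0
    lstm_steps = []
    for t in range(x.size(0)):
        _, (h, c) = lstm(x[t:t + 1], (h, c))  # input of shape (1, batch, input_size)
        lstm_steps.append(h[0])               # drop the num_layers dimension

    # nn.LSTMCell fed one timestep at a time.
    h, c = h0[0], c0[0]                       # nn.LSTMCell expects (batch, hidden_size)
    cell_steps = []
    for t in range(x.size(0)):
        h, c = lstm_cell(x[t], (h, c))
        cell_steps.append(h)

lstm_h = torch.stack(lstm_steps)  # (8, 1, 5)
cell_h = torch.stack(cell_steps)  # (8, 1, 5)
print(torch.allclose(lstm_h, cell_h, atol=1e-6))
# → True
```

A small tolerance is used because the two modules may call different kernels, which can differ at the level of float32 rounding.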

Okay, it looks like you are right: the results are now identical. So the two modules appear to be equivalent for the “single-layer, one-step-at-a-time decoding” use case, setting aside the difference in initialization. Thanks.
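For the decoding use case itself, here is a minimal sketch of an autoregressive loop in which each step’s output is fed back as the next input, which is why a decoder has to run one step at a time. The projection layer `proj`, the functions `decode_with_lstm` / `decode_with_cell`, the batch size and the step count are all hypothetical. With the cell given the LSTM’s layer-0 weights via `load_state_dict`, the nn.LSTM loop (fed sequences of length 1) and the nn.LSTMCell loop should produce the same outputs:

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden = 5
lstm = nn.LSTM(input_size=hidden, hidden_size=hidden, num_layers=1)
cell = nn.LSTMCell(hidden, hidden)
proj = nn.Linear(hidden, hidden)  # hypothetical: maps a hidden state to the next input

# Give the cell the LSTM's layer-0 weights (keys must match LSTMCell's state_dict).
cell.load_state_dict({
    "weight_ih": lstm.weight_ih_l0,
    "weight_hh": lstm.weight_hh_l0,
    "bias_ih": lstm.bias_ih_l0,
    "bias_hh": lstm.bias_hh_l0,
})

def decode_with_lstm(x, h, c, steps):
    """nn.LSTM, one step at a time; adds/removes the seq_len and num_layers dims."""
    outputs = []
    for _ in range(steps):
        out, (h, c) = lstm(x.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0)))
        h, c = h.squeeze(0), c.squeeze(0)
        x = proj(out.squeeze(0))  # feed the prediction back in as the next input
        outputs.append(x)
    return torch.stack(outputs)

def decode_with_cell(x, h, c, steps):
    """nn.LSTMCell, one step at a time; states are (batch, hidden)."""
    outputs = []
    for _ in range(steps):
        h, c = cell(x, (h, c))
        x = proj(h)  # same feedback rule as above
        outputs.append(x)
    return torch.stack(outputs)

with torch.no_grad():
    batch = 3
    x0 = torch.randn(batch, hidden)
    h0 = torch.zeros(batch, hidden)
    c0 = torch.zeros(batch, hidden)
    y_lstm = decode_with_lstm(x0, h0, c0, steps=6)
    y_cell = decode_with_cell(x0, h0, c0, steps=6)

print(y_lstm.shape)
# → torch.Size([6, 3, 5])
print(torch.allclose(y_lstm, y_cell, atol=1e-5))
# → True
```

The only extra bookkeeping on the nn.LSTM side is adding and removing the leading length-1 dimensions; for a single layer, `out` at the last (and only) timestep equals the returned `h`.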