Let’s say I want to use an nn.LSTM as an nn.LSTMCell. First I initialize the two modules:
import torch
import torch.nn as nn

lstm = nn.LSTM(5, 10, batch_first=True)
cell = nn.LSTMCell(5, 10)
Then I copy the LSTM’s layer-0 weights and biases into the LSTMCell so the two modules compute the same function:
cell.weight_ih.data = lstm.weight_ih_l0.data
cell.weight_hh.data = lstm.weight_hh_l0.data
cell.bias_ih.data = lstm.bias_ih_l0.data
cell.bias_hh.data = lstm.bias_hh_l0.data
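(As an aside, the same copy can also be done without assigning to .data, using in-place copy_ under no_grad; this is a sketch using the same module setup as above:)

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(5, 10, batch_first=True)
cell = nn.LSTMCell(5, 10)

# Copy the layer-0 LSTM parameters into the cell in place,
# without replacing the parameters' underlying tensors.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)
```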
Next I initialize the hidden and cell states and a random input (batch size 1, sequence length 2):
h, c = torch.zeros(1, 10), torch.zeros(1, 10)
inp = torch.randn(1, 2, 5)
I run the input through the LSTMCell one timestep at a time, and through the LSTM in a single call:
# Cell
h1, c1 = cell(inp[:, 0], (h, c))
h2, c2 = cell(inp[:, 1], (h1, c1))
# LSTM
o_l1, (h_l1, c_l1) = lstm(inp, (h.unsqueeze(1), c.unsqueeze(1)))
# This all makes sense
assert torch.isclose(h_l1, h2).all()
assert torch.isclose(o_l1[:, 0], h1).all()
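(For completeness, here is the same sanity check extended to the cell states, which should also agree between the two modules given the weight copy above; a self-contained sketch, with the hidden state reshaped to the documented (num_layers, batch, hidden) layout:)

```python
import torch
import torch.nn as nn

# Same setup as above: copy the LSTM's layer-0 weights into the cell
lstm = nn.LSTM(5, 10, batch_first=True)
cell = nn.LSTMCell(5, 10)
cell.weight_ih.data = lstm.weight_ih_l0.data
cell.weight_hh.data = lstm.weight_hh_l0.data
cell.bias_ih.data = lstm.bias_ih_l0.data
cell.bias_hh.data = lstm.bias_hh_l0.data

h, c = torch.zeros(1, 10), torch.zeros(1, 10)
inp = torch.randn(1, 2, 5)

# Step the cell twice, then run the LSTM over the whole sequence
h1, c1 = cell(inp[:, 0], (h, c))
h2, c2 = cell(inp[:, 1], (h1, c1))
o_l1, (h_l1, c_l1) = lstm(inp, (h.unsqueeze(0), c.unsqueeze(0)))

# The final cell states should match as well, not just the hidden states
assert torch.isclose(c_l1.squeeze(0), c2, atol=1e-6).all()
```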
But now, if I use the LSTM to process only the second timestep, passing it the hidden states generated by the LSTMCell:
o_l2, (h_l2, c_l2) = lstm(inp[:, 1].unsqueeze(1), (h1.unsqueeze(1), h1.unsqueeze(1)))
print(h_l2)
# tensor([[[ 0.0331, -0.0655, -0.0764, 0.0269, -0.1706, -0.0454, 0.0769,
# 0.0008, -0.0629, -0.1064]]], grad_fn=<StackBackward0>)
print(h2)
# tensor([[ 0.0454, -0.0672, -0.0577, 0.0335, -0.1670, -0.0498, 0.0721, 0.0143,
# -0.0400, -0.1339]], grad_fn=<MulBackward0>)
h_l2 and h2 appear to be different. Is there something wrong with what I am doing? I expected them to produce the same values.
Here is the colab link to this scenario.
Thank you for the help.