The hidden states of two reversed sequences in a BiLSTM don't match

Suppose I input two sequences, [1,2,3,4,5] and [5,4,3,2,1], into a BiLSTM (imagine the numbers are word indices). In my understanding, since the backward direction reads the second sequence as 1,2,3,4,5, the forward hidden state of the first sequence should be equal to the backward hidden state of the second sequence, and vice versa. However, the results don't match.
Here is my snippet:

import torch
import torch.nn as nn

batch_size = 1
seq_len = 5
hidden_size = 3
embed_dim = 5
embed_size = 10  # vocabulary size
num_layers = 1
num_directions = 2

embedding = nn.Embedding(embed_size, embed_dim)

a = torch.tensor([[1,2,3,4,5]])
b = torch.tensor([[5,4,3,2,1]])
rnn = nn.LSTM(embed_dim, hidden_size, batch_first=True, bidirectional=True, dropout=0)

def my_rnn(input_seq):
    e = embedding(input_seq)
    o, (h, c) = rnn(e)

    # with batch_first=True, o is (batch, seq_len, num_directions * hidden_size);
    # h and c are (num_layers * num_directions, batch, hidden_size) regardless
    o = o.view(batch_size, seq_len, num_directions, hidden_size)
    h = h.view(num_layers, num_directions, batch_size, hidden_size)
    c = c.view(num_layers, num_directions, batch_size, hidden_size)

    return o, h, c

ao, ah, ac = my_rnn(a)
bo, bh, bc = my_rnn(b)

# expected: a's forward hidden state == b's backward hidden state
print(ah[0, 0], bh[0, 1])

# expected: b's forward hidden state == a's backward hidden state
print(bh[0, 0], ah[0, 1])

The output is:

tensor([[-0.3088, -0.1794,  0.3160]], grad_fn=<SelectBackward>) tensor([[-0.0156,  0.0029,  0.1224]], grad_fn=<SelectBackward>)
tensor([[-0.1858,  0.0219, -0.0091]], grad_fn=<SelectBackward>) tensor([[-0.2032, -0.0276,  0.2763]], grad_fn=<SelectBackward>)
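As a sanity check on the indexing (a small sketch reusing ao and ah from the snippet above, and relying on the o view shape shown there), within a single run the final hidden state of each direction should line up with the per-timestep outputs:

# forward final hidden state == forward output at the last timestep
print(torch.allclose(ao[:, -1, 0], ah[0, 0]))  # expected: True
# backward final hidden state == backward output at the first timestep
print(torch.allclose(ao[:, 0, 1], ah[0, 1]))   # expected: True

If both checks pass, the mismatch above is not an indexing problem.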

This post by @vdw explains the outputs, directions, etc., and might be helpful.

Maybe this post is also relevant.
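For quick reference, the two directions of a bidirectional nn.LSTM hold separate parameters (a minimal sketch; the attribute names are PyTorch's, with the _reverse suffix marking the backward direction):

# each direction has its own weights, so reading 1,2,3,4,5 via the forward
# direction and via the backward direction applies different parameters
print(rnn.weight_ih_l0.shape)          # forward input-to-hidden weights
print(rnn.weight_ih_l0_reverse.shape)  # backward input-to-hidden weights
print(torch.equal(rnn.weight_ih_l0, rnn.weight_ih_l0_reverse))  # False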
