How to combine forward hidden states and backward hidden states of each time step

Hi everyone,

I am trying to process a sentence word by word with a BiLSTM. My aim is to concatenate the forward and backward LSTM’s hidden states right after each token is processed.

For example, let’s assume our sentence is: TOK1, TOK2, TOK3. The forward LSTM processes this sequence in the order TOK1, TOK2, TOK3 and produces the hidden states hf1, hf2, hf3. The backward LSTM processes this sequence in the order TOK3, TOK2, TOK1 and produces the hidden states hb1, hb2, hb3. (hb1 is the hidden state created right after TOK3 is processed.)

So, I want to combine:
hb1 with hf3,
hb2 with hf2,
hb3 with hf1

To do that, I did the following:

import torch
import torch.nn as nn

hidden_dim = 200
embedding_dim = 10
vocab_size = 6
ids = torch.tensor([1, 5, 4])
bs = 1
embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, bidirectional=True)
embeds = embedding(ids)
output, _ = lstm(embeds)
output = output.view(3, bs, 2, hidden_dim)  # 3 = sequence_length, 2 = number of directions

What I don’t know is in which order the hidden states of the backward LSTM are stored in this output variable.

For example, is output[2, 0, 1, :] the hidden state hb1 or hb3 according to my annotation above?

EDIT

Before reshaping, the output variable was a tensor with dimensions (3, bs, 2*hidden_dim). In that version, are the hidden states stored in the order I asked about, or differently? I could have answered these questions myself if I had been able to set the forward and backward weights of the BiLSTM to the same values.
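(For what it’s worth, I believe nn.LSTM exposes the per-direction parameters as weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0 and the same names with a _reverse suffix, so tying the two directions for such an experiment might look like the rough, unverified sketch below.)

# Rough sketch: make the backward direction compute the same function as the forward one
with torch.no_grad():
    lstm.weight_ih_l0_reverse.copy_(lstm.weight_ih_l0)
    lstm.weight_hh_l0_reverse.copy_(lstm.weight_hh_l0)
    lstm.bias_ih_l0_reverse.copy_(lstm.bias_ih_l0)
    lstm.bias_hh_l0_reverse.copy_(lstm.bias_hh_l0)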

To be more useful, the first index is always time-ascending w.r.t. the input (it doesn’t make much sense to get a concatenation of the forward step for t=0 and the backward step for t=seq_len-1), so output[t, b] corresponds to the output after input[t, b] has been read. In your notation, that means output[2, 0, 1, :] is hb1, the backward state produced right after TOK3 was read.
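If you want to convince yourself numerically, one way (a quick, untested sketch, not an official recipe) is to copy the BiLSTM’s reverse-direction parameters into a separate unidirectional LSTM, run that on the flipped sequence, and compare against the backward half of the bidirectional output:

import torch
import torch.nn as nn

hidden_dim, embedding_dim, vocab_size = 200, 10, 6
embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, bidirectional=True)

ids = torch.tensor([1, 5, 4]).unsqueeze(1)   # (seq_len, batch=1)
embeds = embedding(ids)
output, _ = lstm(embeds)                     # (seq_len, 1, 2*hidden_dim)

# Unidirectional LSTM carrying the BiLSTM's reverse-direction weights
back = nn.LSTM(embedding_dim, hidden_dim, num_layers=1)
with torch.no_grad():
    back.weight_ih_l0.copy_(lstm.weight_ih_l0_reverse)
    back.weight_hh_l0.copy_(lstm.weight_hh_l0_reverse)
    back.bias_ih_l0.copy_(lstm.bias_ih_l0_reverse)
    back.bias_hh_l0.copy_(lstm.bias_hh_l0_reverse)

back_out, _ = back(embeds.flip(0))           # hidden states in TOK3, TOK2, TOK1 order
# Flipping back aligns them with the input order, which is how the BiLSTM stores them.
print(torch.allclose(output[:, :, hidden_dim:], back_out.flip(0), atol=1e-6))  # expected: True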

Best regards

Thomas

Thank you, Thomas. Just to make sure, in the case of my example:

hidden_dim = 200
embedding_dim = 10
vocab_size = 6
ids = torch.tensor([1, 5, 4])
bs = 1
embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, bidirectional=True)
embeds = embedding(ids)
output, _ = lstm(embeds)

output[0,1,:] contains the hidden states generated just after both the forward and the backward LSTM have consumed “1”, right?

Yes.

If I were in a nitpicking mood, I would prefer to say “output” to avoid confusing it with the LSTM’s internal states (“_”, aka (h, c)), and say that you want output[0, 0] and an unsqueeze(1) on the ids to get the batch dimension right.
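Roughly like this, as a sketch (the variable name first_token_output is just illustrative):

ids = torch.tensor([1, 5, 4]).unsqueeze(1)   # (seq_len, batch=1) instead of (seq_len,)
embeds = embedding(ids)                      # (seq_len, 1, embedding_dim)
output, _ = lstm(embeds)                     # (seq_len, 1, 2*hidden_dim)
first_token_output = output[0, 0]            # forward + backward outputs for the first token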

But I’m not, so:

Best regards

Thomas