Bidirectional LSTM with different sequence lengths

Hi,
I am trying to implement a bidirectional LSTM with PPO, which is an on-policy algorithm. Because of how this algorithm works, we usually collect a rollout of experiences even though the episode itself has not finished. Hence, we make the (Bi)LSTM stateful along the episode and reset its hidden states when a new episode begins.
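A minimal sketch of the stateful collection loop described above, assuming a toy nn.LSTM, random observations in place of an environment, and a hypothetical `dones` list marking where an episode ends. `collect_rollout` and `zero_state` are illustrative helpers written for this sketch, not part of PyTorch or any RL library. A unidirectional LSTM is used to keep it short.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 4, 3
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)


def zero_state(lstm, batch_size=1):
    # Hypothetical helper: zero (h, c) of shape
    # (num_layers * num_directions, batch, hidden_size).
    num_directions = 2 if lstm.bidirectional else 1
    shape = (lstm.num_layers * num_directions, batch_size, lstm.hidden_size)
    return torch.zeros(shape), torch.zeros(shape)


def collect_rollout(lstm, observations, dones, state):
    """Hypothetical helper: feed one observation per call (seq_len=1),
    carrying (h, c) between calls and resetting it after an episode ends.

    Returns the stacked outputs, the (h, c) each step started from
    (which could be stored for the optimization phase), and the state
    to carry into the next rollout.
    """
    h, c = state
    outputs, start_states = [], []
    with torch.no_grad():
        for obs_t, done in zip(observations, dones):
            start_states.append((h, c))
            out, (h, c) = lstm(obs_t.view(1, 1, -1), (h, c))
            outputs.append(out)
            if done:  # episode finished: the next step starts from zeros
                h, c = zero_state(lstm)
    return torch.cat(outputs), start_states, (h, c)


# Usage: a rollout of 5 steps in which the episode ends after step 2.
observations = torch.randn(5, input_size)
dones = [False, False, True, False, False]
outputs, start_states, carry = collect_rollout(
    lstm, observations, dones, zero_state(lstm))
print(outputs.shape)  # torch.Size([5, 1, 3])
```

The state returned as `carry` is what would be passed into the next rollout when the episode continues across the rollout boundary.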

My question concerns the BiLSTM's behaviour. When I collect experiences I collect them one by one, so sequence_length=1. However, when doing an optimization step I pass all the experiences of the current rollout together, so sequence_length=rollout_length. Comparing these two scenarios, the forward direction gives the same (expected) hidden states and outputs, but the reverse-direction part of the network's output differs.
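A compact check of the behaviour described above, assuming arbitrary sizes and a fixed seed. The helper `step_by_step` (hypothetical, written for this sketch) mimics collection with sequence_length=1, and its outputs are compared with a single full-sequence call, first for a unidirectional LSTM and then for a bidirectional one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_size, hidden_size = 5, 1, 4, 3
obs = torch.randn(seq_len, batch_size, input_size)


def step_by_step(lstm, obs):
    # Feed one time step per call (seq_len=1), carrying (h, c) between calls,
    # as during rollout collection.
    num_directions = 2 if lstm.bidirectional else 1
    shape = (lstm.num_layers * num_directions, obs.size(1), lstm.hidden_size)
    h, c = torch.zeros(shape), torch.zeros(shape)
    outputs = []
    for t in range(obs.size(0)):
        out, (h, c) = lstm(obs[t:t + 1], (h, c))
        outputs.append(out)
    return torch.cat(outputs)  # (seq_len, batch, num_directions * hidden_size)


uni = nn.LSTM(input_size, hidden_size)
bi = nn.LSTM(input_size, hidden_size, bidirectional=True)
with torch.no_grad():
    uni_steps = step_by_step(uni, obs)
    uni_full, _ = uni(obs)  # whole sequence, zero initial state
    bi_steps = step_by_step(bi, obs)
    bi_full, _ = bi(obs)

fwd = slice(0, hidden_size)                # forward-direction features
bwd = slice(hidden_size, 2 * hidden_size)  # backward-direction features
print(torch.allclose(uni_steps, uni_full, atol=1e-6))                    # True
print(torch.allclose(bi_steps[..., fwd], bi_full[..., fwd], atol=1e-6))  # True
print(torch.allclose(bi_steps[..., bwd], bi_full[..., bwd], atol=1e-6))  # False
```

Only the backward features disagree: for a unidirectional LSTM, chunking the sequence and carrying (h, c) is exact.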

I attach some code to explain my issue:

```python
import torch
import torch.nn as nn

input_size = 10  # size of each input vector (e.g. the output of a CNN)
num_layers = 1  # number of stacked LSTM layers
hidden_size = 1  # number of hidden units ("neurons")
seq_len = 3  # sequence length (i.e. number of observations in a sequence)
batch_size = 1
bidirectional = True

# define the LSTM (default layout: (seq_len, batch, features))
lstm = nn.LSTM(input_size=input_size,
               hidden_size=hidden_size,
               num_layers=num_layers,
               bidirectional=bidirectional)

# initial hidden and cell states: (num_layers * num_directions, batch, hidden_size)
h = torch.zeros(num_layers*(1+int(bidirectional)), batch_size, hidden_size)
c = torch.zeros(num_layers*(1+int(bidirectional)), batch_size, hidden_size)

# -----------------------------------------------------------------------------
# lists for monitoring the step-by-step pass
hidden = []
output = []
obs = torch.randn(seq_len, batch_size, input_size)

# 1. step-by-step forward pass (seq_len=1 per call), carrying (h, c)
for i in range(seq_len):
    inp = obs[i].unsqueeze(0)  # shape (1, batch_size, input_size)
    out, (h, c) = lstm(inp, (h, c))
    hidden.append(h)
    output.append(out)

# 2. whole sequence at once, starting again from zero states
h = torch.zeros(num_layers*(1+int(bidirectional)), batch_size, hidden_size)
c = torch.zeros(num_layers*(1+int(bidirectional)), batch_size, hidden_size)
out, (h, c) = lstm(obs, (h, c))
```

And the output of the code is the following:

```
hidden
Out[120]: 
[tensor([[[0.0295]],

         [[0.3094]]], grad_fn=<StackBackward>),
 tensor([[[-0.3171]],

         [[ 0.3956]]], grad_fn=<StackBackward>),
 tensor([[[0.0404]],

         [[0.3690]]], grad_fn=<StackBackward>)]

h
Out[121]: 
tensor([[[0.0404]],

        [[0.2603]]], grad_fn=<StackBackward>)

output
Out[122]: 
[tensor([[[0.0295, 0.3094]]], grad_fn=<CatBackward>),
 tensor([[[-0.3171,  0.3956]]], grad_fn=<CatBackward>),
 tensor([[[0.0404, 0.3690]]], grad_fn=<CatBackward>)]

out
Out[123]: 
tensor([[[ 0.0295,  0.2603]],

        [[-0.3171,  0.3946]],

        [[ 0.0404,  0.2417]]], grad_fn=<CatBackward>)
```
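In the printed results, the backward entry of h (0.2603) equals the backward entry of out at the first time step, not the last one. This matches how PyTorch lays out a bidirectional nn.LSTM: the nn.LSTM documentation describes separating the directions with `view` (index 0 forward, 1 backward), and the backward direction finishes after reading the sequence from end to start. A small sketch checking this, assuming random weights and sizes chosen to mirror the example above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_size, hidden_size, num_layers = 3, 1, 10, 1, 1
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
obs = torch.randn(seq_len, batch_size, input_size)

with torch.no_grad():
    out, (h_n, c_n) = lstm(obs)  # zero initial states by default

# Direction index 0 is forward, 1 is backward (batch_first=False layout).
out_dirs = out.view(seq_len, batch_size, 2, hidden_size)
h_dirs = h_n.view(num_layers, 2, batch_size, hidden_size)
fwd_out, bwd_out = out_dirs[:, :, 0], out_dirs[:, :, 1]

# The forward direction ends at the last step, the backward one at the first.
fwd_matches_last = torch.allclose(h_dirs[-1, 0], fwd_out[-1])
bwd_matches_first = torch.allclose(h_dirs[-1, 1], bwd_out[0])
print(fwd_matches_last, bwd_matches_first)  # True True
```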

As can be seen, for the backward (“future/reversed”) direction of the LSTM, neither the step-by-step hidden states nor the step-by-step outputs match h and out from the whole-sequence pass. The forward direction matches in both cases.
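One way to see why the step-by-step run cannot reproduce the backward half: for a single layer, the backward direction of a bidirectional nn.LSTM behaves like an ordinary LSTM, using the `*_reverse` parameters, run over the time-reversed sequence. Its output at step t therefore depends on step t and every later step, which have not been collected yet when feeding one step at a time. The sketch below (arbitrary sizes and seed) copies those weights into a unidirectional LSTM and checks the equivalence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_size, hidden_size = 3, 1, 10, 1
bi = nn.LSTM(input_size, hidden_size, bidirectional=True)
# A unidirectional LSTM that will hold a copy of the backward-direction weights.
rev = nn.LSTM(input_size, hidden_size)
with torch.no_grad():
    for name in ("weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"):
        getattr(rev, name).copy_(getattr(bi, name + "_reverse"))

obs = torch.randn(seq_len, batch_size, input_size)
with torch.no_grad():
    bi_out, _ = bi(obs)
    # Run the copied LSTM over the time-reversed sequence, then flip back.
    rev_out, _ = rev(torch.flip(obs, dims=[0]))
    rev_out = torch.flip(rev_out, dims=[0])

print(torch.allclose(bi_out[..., hidden_size:], rev_out, atol=1e-6))  # True
```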

My questions are:

  1. I guess the correct output is the one obtained by passing all the samples as a full sequence (i.e. out/h), since the backward direction can then take the whole sequence into account instead of seeing just one step at a time. Am I right?
  2. Since we truncate the episode into separate rollouts and update the network parameters after each one… should we do another forward pass from the previous hidden state to get a new hidden state consistent with the updated parameters? I.e., after passing our 5 sequences through the network with an initial hidden state (h0) we get h1; however, if we now update the network and pass the same observations and h0 again, we get h1’ instead of h1.
    I have seen this referred to as the “stale hidden state” issue, and there is a lot of confusion about whether it has to be updated or not.
  3. Since we truncate the episode into separate rollouts and update the network parameters after each one… do we simply pass the last/new hidden state (h1) as the input for the next rollout? That makes sense to me for a classic LSTM, which stores and carries past information, but I cannot see how to carry over the truncated “future” (backward) part, since the information it passes is messy once the sequence is truncated.
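The premise of question 2 can be checked directly. A minimal sketch, assuming a unidirectional LSTM, a 5-step rollout, and a dummy squared-output loss with plain SGD standing in for the PPO update (this is not the PPO objective): it records h1 before the update, recomputes it from the same h0 and observations afterwards (h1’), and shows that the two differ. It does not settle which of the two should be carried into the next rollout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 4, 3
lstm = nn.LSTM(input_size, hidden_size)  # unidirectional for simplicity
optimizer = torch.optim.SGD(lstm.parameters(), lr=1.0)

obs = torch.randn(5, 1, input_size)  # one rollout of 5 steps
h0 = torch.zeros(1, 1, hidden_size)
c0 = torch.zeros(1, 1, hidden_size)

# h1: state at the end of the rollout under the collection-time parameters.
with torch.no_grad():
    _, (h1, c1) = lstm(obs, (h0, c0))

# One gradient step with a dummy loss, standing in for the policy update.
out, _ = lstm(obs, (h0, c0))
loss = out.pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# h1': the same observations and h0, re-run with the updated parameters.
with torch.no_grad():
    _, (h1_new, c1_new) = lstm(obs, (h0, c0))

print(torch.allclose(h1, h1_new))  # False: h1 no longer matches the new weights
```

Re-running the rollout like this costs one extra forward pass per rollout; whether the refreshed h1’ or the stored h1 should seed the next rollout is exactly the open question above.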