How to combine the output states of two LSTMs

Hi there,

I am working on a video classification project. I feed my video frames to two different CNN models (AlexNet and VGG) to extract frame-level features from each model, and then aggregate the output of each model with its own LSTM. My question is: how do I combine the output states of the two LSTMs so I can feed them to fully connected layers for classification?

Thanks in advance :)

I don’t know if there is a standard method to do this in the literature, but I can think of 2 approaches to the problem.
Supposing you are only interested in the last hidden state of both LSTMs, you can:

  1. Concatenate the two LSTMs’ hidden states into a single hidden state and feed it to your MLP.
  2. Sum or average the LSTMs’ hidden states (this requires both to have the same size) and feed the result to your MLP. This is done, for example, to merge the forward and backward features of a bidirectional LSTM.
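Both options can be sketched as one PyTorch module. This is a minimal sketch, not a standard architecture: the class name TwoStreamLSTMClassifier, the fusion argument, the 256-unit hidden layer of the MLP and all sizes are hypothetical choices. It assumes the AlexNet and VGG frame features have already been extracted into tensors of shape (N, T, features), with T frames per video, and that both LSTMs are single-direction.

```python
import torch
import torch.nn as nn


class TwoStreamLSTMClassifier(nn.Module):
    """Hypothetical two-stream model: one LSTM per CNN feature stream, fused before an MLP."""

    def __init__(self, feat_dim1, feat_dim2, hidden1, hidden2, num_classes, fusion="concat"):
        super().__init__()
        if fusion not in ("concat", "sum", "mean"):
            raise ValueError(f"unknown fusion mode: {fusion}")
        if fusion != "concat" and hidden1 != hidden2:
            raise ValueError("sum/mean fusion requires hidden1 == hidden2")
        self.fusion = fusion
        # batch_first=True: each stream is fed as (N, T, feat_dim)
        self.lstm1 = nn.LSTM(feat_dim1, hidden1, batch_first=True)
        self.lstm2 = nn.LSTM(feat_dim2, hidden2, batch_first=True)
        fused_dim = hidden1 + hidden2 if fusion == "concat" else hidden1
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, feats1, feats2):
        # h_n has shape (num_layers * num_directions, N, H); keep the last layer -> (N, H)
        _, (h1, _) = self.lstm1(feats1)
        _, (h2, _) = self.lstm2(feats2)
        h1, h2 = h1[-1], h2[-1]
        if self.fusion == "concat":
            fused = torch.cat((h1, h2), dim=1)  # (N, H1 + H2)
        elif self.fusion == "sum":
            fused = h1 + h2                     # (N, H)
        else:
            fused = (h1 + h2) / 2               # (N, H)
        return self.mlp(fused)


# Usage with random stand-ins for 4096-dim per-frame features (assumed size)
N, T = 4, 16
feats_alex = torch.randn(N, T, 4096)
feats_vgg = torch.randn(N, T, 4096)
model = TwoStreamLSTMClassifier(4096, 4096, 512, 256, num_classes=10)
logits = model(feats_alex, feats_vgg)  # (N, num_classes)
```

Concatenation lets the two LSTMs use different hidden sizes, while sum and mean keep the classifier input smaller but force the sizes to match, which is why the constructor checks it.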

Thank you for the reply. I have two models, and each model is connected to its own LSTM. In each LSTM I concatenate the hidden states into one hidden state, so I end up with two output states from two different LSTMs. How can I fuse these states?

Let’s say you have a hidden state of size (1, N, H), where N is the batch size and H is the number of hidden units. The 1 is there because I assumed you have a single layer and no bidirectionality (in general the first dimension is num_layers * num_directions).
So in your case you have two of these hidden states, one per LSTM, and they can differ only in the value of H. Call H1 the hidden size of the first LSTM and H2 the hidden size of the second LSTM.

You can obtain a single stacked hidden state of size (1, N, H1+H2) with this code:

stacked_hidden = torch.cat((hidden1, hidden2), dim=2)
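Here is a small runnable sketch of where hidden1 and hidden2 come from and what torch.cat produces. The per-frame feature size F = 512, the sequence length and the hidden sizes are illustrative assumptions; random tensors stand in for the CNN features.

```python
import torch
import torch.nn as nn

N, T, F = 8, 20, 512  # batch size, frames per video, feature size (illustrative)
H1, H2 = 128, 64      # hidden sizes of the two LSTMs (illustrative)

lstm1 = nn.LSTM(input_size=F, hidden_size=H1)
lstm2 = nn.LSTM(input_size=F, hidden_size=H2)

# Default layout (batch_first=False): inputs are (T, N, F)
x1 = torch.randn(T, N, F)
x2 = torch.randn(T, N, F)

# nn.LSTM returns output, (h_n, c_n); h_n is the last hidden state
_, (hidden1, _) = lstm1(x1)  # (1, N, H1)
_, (hidden2, _) = lstm2(x2)  # (1, N, H2)

# Concatenate along the hidden dimension
stacked_hidden = torch.cat((hidden1, hidden2), dim=2)
print(stacked_hidden.shape)  # torch.Size([1, 8, 192])
```

Note that h_n keeps the (num_layers * num_directions, N, H) layout even if you construct the LSTM with batch_first=True, so dim=2 is still the hidden dimension.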

If you instead want to sum or average the hidden states, the two LSTMs must have the same hidden size (H1 = H2 = H); then you can simply add the two tensors (hidden1 + hidden2) or average them ((hidden1 + hidden2) / 2).
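A minimal sketch of the sum and average options, assuming both LSTMs were built with the same hidden size H; random tensors stand in for the two final hidden states, and the sizes are illustrative.

```python
import torch

N, H = 8, 128  # batch size and shared hidden size H1 = H2 = H (illustrative)

# Stand-ins for the final hidden states of the two LSTMs, shape (1, N, H)
hidden1 = torch.randn(1, N, H)
hidden2 = torch.randn(1, N, H)

summed = hidden1 + hidden2          # (1, N, H)
averaged = (hidden1 + hidden2) / 2  # (1, N, H)

# Same average, written as stacking on a new dimension and reducing it;
# this form extends directly to more than two streams
averaged_alt = torch.stack((hidden1, hidden2), dim=0).mean(dim=0)
```

Both fused tensors keep the shape of a single hidden state, so the classifier input size stays H no matter how many streams you merge.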

In all these cases you can squeeze the result to remove the leading dimension of size 1 (for example with squeeze(0)) and feed the tensor to the fully connected network. Note that the first layer of this network should have an input size of H1+H2 if you use concatenation, or H if you use sum or average.
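As a sketch of this last step, assuming concatenation fusion and a hypothetical two-layer classifier (the 256-unit hidden layer and num_classes = 10 are illustrative choices, not requirements):

```python
import torch
import torch.nn as nn

N, H1, H2, num_classes = 8, 128, 64, 10  # illustrative sizes

# Stand-ins for the two final hidden states, shape (1, N, H)
hidden1 = torch.randn(1, N, H1)
hidden2 = torch.randn(1, N, H2)

# Concatenate, then drop only dimension 0 -> (N, H1 + H2)
fused = torch.cat((hidden1, hidden2), dim=2).squeeze(0)

# in_features of the first Linear layer must match the fused size
classifier = nn.Sequential(
    nn.Linear(H1 + H2, 256),
    nn.ReLU(),
    nn.Linear(256, num_classes),
)
logits = classifier(fused)  # (N, num_classes)
```

Using squeeze(0) rather than a bare squeeze() matters when N happens to be 1: squeeze() would also remove the batch dimension. For the sum or average variants, replace H1 + H2 with H.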