Faster way to sum recurrent net hidden states

I’m trying to test a slight modification of a multi-layer LSTM: I want to sum all the intermediate hidden outputs before passing the result to an FC layer. My forward pass looks like this:
def forward(self, input_seq):
    out = []
    for i in range(input_seq.shape[0]):
        _, self.hidden = self.rec_net(input_seq[i].unsqueeze(0), self.hidden)
        hidden_sum = self.hidden[0].sum(dim=0).unsqueeze(0)

    return self.fc( # fully connected

It seems to work and converge, but it runs very slowly.
Is there a more efficient way to implement it in PyTorch?


Hi @guyrose3
If you take a quick look at the LSTM documentation, passing a (seq_len, batch, input_size) tensor to an nn.LSTM layer gives a (seq_len, batch, hidden_size * num_directions) output. You can simply call sum() over this output tensor. Borrowing your variable names, here is what the code would look like:

def forward(self, input_seq):
    output, self.hidden = self.rec_net(input_seq, self.hidden)
    # output is now (seq_len, batch, hidden_size * num_directions)
    output = output.sum(dim=0)
    # output is now (batch, hidden_size * num_directions)
    output = self.fc(output)
    # output is now (batch, fc_output_size)
    return output

Hi @dasguptar,
Thanks for your response!
Maybe I wasn’t clear in my question:
I’m trying to solve a many-to-many problem.
My input is [seq_len, batch, input_dim].
My output is [seq_len, batch, output_dim].
My network is a multi-layer LSTM, for example with 2 layers.
The current flow:
h1_new = lstm1(input, h1_old)
h2_new = lstm2(h1_new, h2_old)
out = fc(h2_new)

I want to sum over the layer dimension, such that:
h_sum = h1_new + h2_new  # [seq_len, batch, hidden_dim]
out = fc(h_sum)

It’s like a skip connection in regular feed-forward networks.

What is wrong with this?

h1_new = lstm1(input, h1_old)
h2_new = lstm2(h1_new, h2_old)
out = fc(h1_new + h2_new)

Hi @guyrose3
Sorry for misunderstanding your question. :disappointed:
Going by your updated example, you seem to be doing the right thing, i.e.

h_sum = h1_new + h2_new
out = fc(h_sum)

This of course means that you cannot use the num_layers keyword argument in nn.LSTM, but will need to make the call individually for each layer. But for your use case, I guess this is as optimal as it can get at the Python level.
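For reference, here is a minimal sketch of the per-layer approach described above. The class name SummedLayerLSTM and all the sizes are made up for illustration, and it assumes a unidirectional LSTM with the default (seq_len, batch, feature) layout. Each layer is its own nn.LSTM and is called once on the whole sequence, so there is no Python loop over timesteps, only over layers.

```python
import torch
import torch.nn as nn


class SummedLayerLSTM(nn.Module):
    """Stack of single-layer LSTMs whose per-timestep outputs are summed.

    Hypothetical module for illustration; not part of torch.nn.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):
        super().__init__()
        # One nn.LSTM per layer so that every layer's full output
        # sequence is accessible, not just the top layer's.
        self.layers = nn.ModuleList(
            [
                nn.LSTM(input_dim if i == 0 else hidden_dim, hidden_dim)
                for i in range(num_layers)
            ]
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input_seq, states=None):
        # input_seq: (seq_len, batch, input_dim)
        if states is None:
            states = [None] * len(self.layers)
        layer_input = input_seq
        h_sum = 0
        new_states = []
        for lstm, state in zip(self.layers, states):
            # Each call consumes the whole sequence at once.
            layer_output, new_state = lstm(layer_input, state)
            h_sum = h_sum + layer_output  # (seq_len, batch, hidden_dim)
            new_states.append(new_state)
            layer_input = layer_output
        # nn.Linear acts on the last dimension, i.e. at every timestep.
        return self.fc(h_sum), new_states


model = SummedLayerLSTM(input_dim=8, hidden_dim=16, output_dim=4)
x = torch.randn(5, 3, 8)  # (seq_len, batch, input_dim)
out, states = model(x)
print(out.shape)
```

Returning the per-layer states instead of storing them on self is just a design choice in this sketch; it makes it easy to carry states across calls or to detach them between batches.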

Training is very slow that way.

Is it really that much slower than just using the output of the second layer?

Can you try to profile the forward pass of your model? I am wondering why it would be slower. One possibility is that because you are making K calls to nn.LSTM's forward for K layers, the CUDA kernels might be launched separately every time, whereas with the num_layers argument that overhead would not be there. However, I am not sure of this, so profiling might identify other bottlenecks.
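A rough sketch of how such a profile could be taken, using torch.autograd.profiler on the CPU. The sizes and the helper names run_loop and run_full are placeholders, not taken from the model in this thread. It compares stepping through the sequence one timestep at a time (as in the original forward pass) against a single call on the whole sequence, which is one candidate for the slowdown besides the per-layer calls.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16)
x = torch.randn(50, 4, 8)  # (seq_len, batch, input_size)


def run_loop(x):
    # One nn.LSTM call per timestep, carrying the state forward.
    state = None
    outputs = []
    for t in range(x.shape[0]):
        out_t, state = lstm(x[t].unsqueeze(0), state)
        outputs.append(out_t)
    return torch.cat(outputs, dim=0)


def run_full(x):
    # One nn.LSTM call for the whole sequence.
    out, _ = lstm(x)
    return out


with torch.no_grad():
    with torch.autograd.profiler.profile() as prof_loop:
        out_loop = run_loop(x)
    with torch.autograd.profiler.profile() as prof_full:
        out_full = run_full(x)

# Both variants compute the same per-timestep outputs.
outputs_match = torch.allclose(out_loop, out_full, atol=1e-5)

loop_table = prof_loop.key_averages().table(sort_by="cpu_time_total", row_limit=5)
full_table = prof_full.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(loop_table)
print(full_table)
```

On a GPU the numbers will look different, so the same comparison would need to be rerun there (with CUDA profiling enabled) before drawing conclusions.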

You could test LSTM with num_layers = 2 like this… EDIT: this won’t work because hidden only contains the last timestep.

lstm = nn.LSTM(.., num_layers=2)
_, (hidden, cell) = lstm(input, (hidden, cell))
# hidden is a tensor of shape (num_layers, batch, hidden_size) for a
# unidirectional LSTM, so hidden[0] and hidden[1] are the two layers;
# cell has the same shape, with cell[0] and cell[1] per layer
out = fc(hidden[0] + hidden[1])

The thing is that the hidden and cell state outputs from the LSTM are only for the most recent timestep,
so you would be summing only the last hidden state.
I’m interested in summing these at each timestep (many-to-many problem).
That’s why I tried to implement it with a for loop over the time dimension.

Good point. I hadn’t realised that.
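To make the point about the last timestep concrete, here is a small check with made-up sizes, assuming a unidirectional 2-layer nn.LSTM. The output tensor holds the top layer's hidden state at every timestep, while h_n holds every layer's hidden state at the final timestep only, so the first layer's intermediate timesteps are not available from a single num_layers=2 call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
x = torch.randn(5, 3, 8)  # (seq_len, batch, input_size)

output, (h_n, c_n) = lstm(x)

# output: top layer only, all timesteps -> (seq_len, batch, hidden_size)
print(output.shape)
# h_n: all layers, last timestep only -> (num_layers, batch, hidden_size)
print(h_n.shape)

# The last timestep of output is the top layer's final hidden state.
top_layer_matches = torch.allclose(output[-1], h_n[-1])
print(top_layer_matches)
```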