What is the best way to concatenate the hidden states of a bidirectional LSTM?


I am using a bidirectional LSTM. By default, its output concatenates the forward and backward hidden states along the last dimension. Unfortunately, the final states h_n and c_n are returned with shape (num_directions, batch_size, hidden_dim) for a single-layer LSTM, so the two directions are stacked along the first dimension instead. I am building a seq2seq model whose decoder is unidirectional but has twice the encoder's hidden dim, so that I can pass the encoder's last hidden states directly to it and use an attention mechanism. That means I have to convert the encoder's final h_n and c_n into the same concatenated layout as the output, i.e. shape (1, batch_size, 2 * hidden_dim).
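For reference, here is a minimal sketch of the shape mismatch. It assumes a single-layer `nn.LSTM` with the default `batch_first=False`; the sizes are arbitrary values chosen only for illustration:

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes (not from the question).
seq_len, batch_size, input_size, hidden_size = 5, 3, 8, 16

# Single-layer bidirectional encoder; batch_first=False by default.
encoder = nn.LSTM(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch_size, input_size)
output, (h_n, c_n) = encoder(x)

print(output.shape)  # torch.Size([5, 3, 32]): directions concatenated on the last dim
print(h_n.shape)     # torch.Size([2, 3, 16]): directions stacked on the first dim
print(c_n.shape)     # torch.Size([2, 3, 16])
```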

I came up with the following, but it seems a bit awkward to me:

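    # (num_directions=2, batch, hidden) -> (batch, 2, hidden)
    h_n = h_n.permute(1, 0, 2)
    c_n = c_n.permute(1, 0, 2)
    # merge both directions into the feature dim -> (1, batch, 2 * hidden)
    h_n = torch.reshape(h_n, (1, h_n.shape[0], h_n.shape[2]*2))
    c_n = torch.reshape(c_n, (1, c_n.shape[0], c_n.shape[2]*2))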
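One way to check what the permute + reshape above produces is to compare it with an explicit `torch.cat` of the two directions, and with the matching slices of `output`. This is a sketch assuming a single-layer, `batch_first=False` encoder with arbitrary sizes; `h_reshaped` and `h_cat` are illustrative names. The same reasoning applies to c_n.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Arbitrary illustrative sizes (not from the question).
seq_len, batch_size, input_size, hidden_size = 5, 3, 8, 16
encoder = nn.LSTM(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch_size, input_size)
output, (h_n, c_n) = encoder(x)

# The permute + reshape from the question: (2, B, H) -> (1, B, 2H)
h_reshaped = torch.reshape(h_n.permute(1, 0, 2), (1, batch_size, 2 * hidden_size))

# The same layout as an explicit concatenation:
# h_n[0] is the forward direction, h_n[1] the backward direction.
h_cat = torch.cat((h_n[0], h_n[1]), dim=1).unsqueeze(0)

print(torch.equal(h_reshaped, h_cat))  # True

# Compare with `output`: the forward final state is the first half of the
# last time step, the backward final state is the second half of the first one.
print(torch.allclose(h_cat[0, :, :hidden_size], output[-1, :, :hidden_size]))  # True
print(torch.allclose(h_cat[0, :, hidden_size:], output[0, :, hidden_size:]))   # True
```

This only covers num_layers = 1. With more layers, h_n has shape (num_layers * 2, batch_size, hidden_dim), and the reshape in the question no longer fits. Per the PyTorch docs it can be viewed as (num_layers, 2, batch_size, hidden_dim), so the last layer's forward and backward states are h_n[-2] and h_n[-1].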

Can someone suggest a better way?