I am using a bidirectional LSTM. It returns the concatenated forward and backward hidden states as output, but the final hidden states h_n and c_n are returned with shape (num_directions, batch_size, hidden_dim) for a single-layer LSTM. I am building a seq2seq model whose decoder is unidirectional, but its hidden dim is twice the encoder's hidden dim, so that I can pass the encoder's last hidden states to it directly and use an attention mechanism. So I have to convert the last hidden states from the encoder LSTM into a concatenated version, just like in the outputs.
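To make the layout concrete, here is a minimal sketch that checks where the final states sit inside output. It assumes hypothetical sizes, a single-layer encoder, the default batch_first=False and unpacked (non-padded) input: the forward half of the last time step should equal h_n[0], and the backward half of the first time step should equal h_n[1].

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, input_dim, hidden_dim = 3, 5, 4, 6

torch.manual_seed(0)
# Single-layer bidirectional encoder; batch_first=False (the default),
# so inputs are (seq_len, batch, input_dim).
encoder = nn.LSTM(input_dim, hidden_dim, bidirectional=True)
x = torch.randn(seq_len, batch_size, input_dim)
output, (h_n, c_n) = encoder(x)

print(output.shape)  # torch.Size([5, 3, 12]): both directions concatenated on the last dim
print(h_n.shape)     # torch.Size([2, 3, 6]): (num_layers * num_directions, batch, hidden_dim)

# The forward direction ends at the last time step, the backward
# direction ends at the first time step (it reads the sequence in reverse).
fwd_matches = torch.allclose(output[-1, :, :hidden_dim], h_n[0])
bwd_matches = torch.allclose(output[0, :, hidden_dim:], h_n[1])
print(fwd_matches, bwd_matches)
```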
I came up with the following, but it seems a bit clumsy to me:
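    # (2, batch, hidden) -> (batch, 2, hidden)
    h_n = h_n.permute(1, 0, 2)
    c_n = c_n.permute(1, 0, 2)
    # -> (1, batch, 2 * hidden), forward half first for each batch element
    h_n = torch.reshape(h_n, (1, h_n.shape[0], h_n.shape[2] * 2))
    c_n = torch.reshape(c_n, (1, c_n.shape[0], c_n.shape[2] * 2))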
Can someone suggest a better way?
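One possible alternative, sketched under assumptions rather than offered as the definitive approach: split the directions with view and concatenate them with torch.cat along the feature dimension. The helper name concat_directions is hypothetical. The view(num_layers, num_directions, batch, hidden_size) split follows the layout described in the PyTorch nn.LSTM documentation, and the helper assumes a bidirectional LSTM (num_directions = 2) whose h_n/c_n are contiguous, as they are when returned by the module.

```python
import torch
import torch.nn as nn


def concat_directions(state: torch.Tensor, num_layers: int = 1) -> torch.Tensor:
    """Hypothetical helper: turn a bidirectional LSTM state of shape
    (num_layers * 2, batch, hidden) into (num_layers, batch, 2 * hidden),
    forward half first, matching the layout of the LSTM's output."""
    _, batch, hidden = state.shape
    # Separate layers and directions: (num_layers, 2, batch, hidden).
    state = state.view(num_layers, 2, batch, hidden)
    # Concatenate forward (index 0) and backward (index 1) on the feature dim.
    return torch.cat((state[:, 0], state[:, 1]), dim=-1)


# Usage with hypothetical sizes and a single-layer encoder.
torch.manual_seed(0)
hidden_dim = 6
encoder = nn.LSTM(4, hidden_dim, bidirectional=True)
x = torch.randn(5, 3, 4)  # (seq_len, batch, input_dim)
output, (h_n, c_n) = encoder(x)

h_dec = concat_directions(h_n)  # (1, 3, 12)
c_dec = concat_directions(c_n)  # (1, 3, 12)

# For one layer this is the same as the permute + reshape version in the question.
h_perm = h_n.permute(1, 0, 2).reshape(1, h_n.shape[1], h_n.shape[2] * 2)

# Multi-layer case: one concatenated state per layer.
encoder2 = nn.LSTM(4, hidden_dim, num_layers=2, bidirectional=True)
output2, (h_n2, c_n2) = encoder2(x)
h_dec2 = concat_directions(h_n2, num_layers=2)  # (2, 3, 12)
```

For the single-layer case the helper reduces to torch.cat((h_n[0], h_n[1]), dim=-1).unsqueeze(0). The view-based version also covers a multi-layer encoder, producing states that a unidirectional decoder with the same num_layers and a hidden size of 2 * hidden_dim can take as its initial (h_0, c_0).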