Using the last hidden states of an LSTM? How to reshape?

I have (for illustration) an LSTM(insize, hiddensize, num_layers=3, bidirectional=True, batch_first=True) and I want to use the last hidden state of each instance in a batch as an input to a Linear layer.

The Linear layer requires a tensor with the batch instances in the first dimension, but the LSTM returns the last hidden state h_n with shape (num_layers * isbidirectional, batchsize, hiddensize), where isbidirectional is 2 if bidirectional, otherwise 1, even if batch_first=True.
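For concreteness, here is a minimal sketch (all sizes below are made up purely for illustration) showing the shapes PyTorch returns:

import torch
import torch.nn as nn

# Arbitrary toy sizes, just for illustration.
insize, hiddensize, num_layers, batchsize, seqlen = 4, 8, 3, 5, 7
lstm = nn.LSTM(insize, hiddensize, num_layers=num_layers,
               bidirectional=True, batch_first=True)
x = torch.randn(batchsize, seqlen, insize)
out, (h_n, c_n) = lstm(x)
print(out.size())  # torch.Size([5, 7, 16]) -> (batchsize, seqlen, 2 * hiddensize), batch first
print(h_n.size())  # torch.Size([6, 5, 8])  -> (num_layers * 2, batchsize, hiddensize), batch NOT first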

Maybe I am missing something here, but would it not be much more convenient if the hidden state were of shape (batchsize, num_layers * isbidirectional, hiddensize)?

What is the best way to create the necessary tensor that can be used as an input to the linear layer so that the batches are in the first dimension?

One thing I want to achieve is to just create a 2-dimensional tensor of shape (batchsize, num_layers * isbidirectional * hiddensize), so that the second dimension has all the hidden vectors that belong to each element in the batch concatenated. How would one do this ideally?

Maybe use

h_n.transpose(0, 1).contiguous().view(batchsize, -1)

The contiguous() call is needed because transpose() does not copy data, it returns a non-contiguous view, and view() only works on a contiguous tensor. So I think this cannot be done without copying data at the contiguous() step, or does anyone know a better solution?
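Put together, a minimal sketch of that approach (continuing the toy sizes from the snippet above; the 10 output features of the Linear layer are an arbitrary choice):

# Flatten h_n so each batch element gets all of its hidden vectors
# concatenated in one row.
flat = h_n.transpose(0, 1).contiguous().view(batchsize, -1)
print(flat.size())  # torch.Size([5, 48]) -> (batchsize, num_layers * 2 * hiddensize)

# Feed the flattened hidden states to a Linear layer.
linear = nn.Linear(num_layers * 2 * hiddensize, 10)
print(linear(flat).size())  # torch.Size([5, 10])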


I’m replying because I ran into the same problem, and I can’t really understand why one would be forced to use things like h_n.transpose(0, 1).contiguous().view(batchsize, -1) or input_to_cl = out.permute(1, 0, 2).view(-1, self.hidden_dim * len(x)) (suggested in another post), which are quite obscure to me. It also sounds super cumbersome for a framework such as PyTorch, so I did a test:

import torch
import torch.nn as nn

test = torch.tensor([[[1, 2, 3]], [[1, 2, 3]]]).type(torch.FloatTensor)
# tensor of size [2, 1, 3], 2 being the batch dimension
lstm = nn.LSTM(3, 6, batch_first=True)
lstm(test)[0].size()

outputs torch.Size([2, 1, 6]), which means that constructing the LSTM with batch_first=True makes it return the output tensor with the batch dimension first. Note, however, that batch_first only applies to the input and output tensors: the hidden state h_n keeps the shape (num_layers * num_directions, batchsize, hiddensize) regardless, so the transpose above is still needed when you want the last hidden state rather than the full output sequence.
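A quick check with the same toy LSTM makes the difference visible (a sketch; the shapes follow directly from running the snippet above):

out, (h_n, c_n) = lstm(test)
print(out.size())  # torch.Size([2, 1, 6]) -> (batch, seq, hiddensize), batch first
print(h_n.size())  # torch.Size([1, 2, 6]) -> (num_layers * num_directions, batch, hiddensize)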