Pack padded sequence last hidden state shape

Hi,
My reading of the LSTM documentation was that, with batch_first=True, a pack_padded_sequence input to a bidirectional LSTM would also give a batch-first last hidden state, i.e. of shape (batch, num_directions, hidden_size). However, the last hidden state comes out with shape (num_directions, batch, hidden_size), even though batch_first is set to True. This is the code:

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

sequences = torch.tensor([[1, 2, 0, 0, 0, 0],      # length 2
                          [3, 4, 5, 0, 0, 0],      # length 3
                          [5, 6, 0, 0, 0, 0],      # length 2
                          [8, 9, 10, 11, 12, 0]],  # length 5
                         dtype=torch.float)
seq_lengths = torch.tensor([2, 3, 2, 5])

rnn = nn.LSTM(input_size=1, hidden_size=5,
              batch_first=True,
              bidirectional=True)

packed_sequences = pack_padded_sequence(sequences.unsqueeze(2),  # (batch, seq_len, 1)
                                        lengths=seq_lengths,
                                        batch_first=True,
                                        enforce_sorted=False)

rnn_output, (hn, cn) = rnn(packed_sequences)
print(hn.shape)  # torch.Size([2, 4, 5]). Shouldn't it be torch.Size([4, 2, 5])?
```

@albanD, could you please help?

Hi,

Unfortunately I have no experience with the recurrent modules and the pack padded sequence mechanics.
From what I can see, could it be that the batch_first distinction only concerns the sequence dimension, which some tensors have and others don't? And that bidirectional always adds an extra dimension of size 2 at position 0, on top of whatever you already had?

Thank you for your answer. That is most likely what happens. However, even with bidirectional set to False, the hidden state still has shape (1, 4, 5), so I believe transposing it to (4, 1, 5) would solve the problem. I just wanted to make sure I'm doing the right thing. Thanks a lot.
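A minimal sketch of that transposition, reusing the toy batch from the question (weights are random, so only the shapes matter). `transpose(0, 1)` swaps the direction and batch dimensions. As an alternative that is not from the thread, the final forward and backward states can be concatenated into one vector per sequence, which is a common choice when the result feeds a classifier.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)

# Same toy batch as in the question: 4 sequences padded to length 6.
sequences = torch.tensor([[1, 2, 0, 0, 0, 0],
                          [3, 4, 5, 0, 0, 0],
                          [5, 6, 0, 0, 0, 0],
                          [8, 9, 10, 11, 12, 0]], dtype=torch.float)
seq_lengths = torch.tensor([2, 3, 2, 5])

rnn = nn.LSTM(input_size=1, hidden_size=5, batch_first=True, bidirectional=True)
packed = pack_padded_sequence(sequences.unsqueeze(2), seq_lengths,
                              batch_first=True, enforce_sorted=False)
_, (hn, cn) = rnn(packed)

# hn is (num_layers * num_directions, batch, hidden_size) = (2, 4, 5),
# already restored to the original batch order (enforce_sorted=False).
hn_batch_first = hn.transpose(0, 1).contiguous()  # (4, 2, 5)

# Alternative: last layer's forward (hn[-2]) and backward (hn[-1]) states side by side.
hn_concat = torch.cat([hn[-2], hn[-1]], dim=1)  # (4, 10)

print(hn.shape)              # torch.Size([2, 4, 5])
print(hn_batch_first.shape)  # torch.Size([4, 2, 5])
print(hn_concat.shape)       # torch.Size([4, 10])
```

The transpose keeps the per-direction split; the concatenation discards it in favour of a flat feature vector per sample.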

After asking around: the hidden state is simply not affected by the batch_first argument, so it always has shape (num_layers * num_directions, batch, hidden_size), whatever you pass.
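A small sketch to check this, using hypothetical sizes (2 layers, batch of 4, padded length 6, hidden size 5). It runs the same bidirectional configuration with batch_first=False and batch_first=True: the padded output from pad_packed_sequence follows batch_first, while h_n has the same layout in both cases. The last line separates layers and directions as (num_layers, num_directions, batch, hidden_size), the split described in the nn.LSTM documentation.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

num_layers, batch, seq_len, hidden = 2, 4, 6, 5
x = torch.randn(batch, seq_len, 1)       # batch-first input
lengths = torch.tensor([2, 3, 2, 5])

shapes = {}
for batch_first in (False, True):
    rnn = nn.LSTM(1, hidden, num_layers=num_layers,
                  batch_first=batch_first, bidirectional=True)
    # Feed (seq_len, batch, 1) when batch_first is False.
    inp = x if batch_first else x.transpose(0, 1)
    packed = pack_padded_sequence(inp, lengths, batch_first=batch_first,
                                  enforce_sorted=False)
    out_packed, (hn, cn) = rnn(packed)
    # Pads to the longest sequence (5), in the layout chosen by batch_first.
    out, _ = pad_packed_sequence(out_packed, batch_first=batch_first)
    shapes[batch_first] = (tuple(out.shape), tuple(hn.shape))
    print(f"batch_first={batch_first}: output {tuple(out.shape)}, h_n {tuple(hn.shape)}")

# Split h_n of the last run into (num_layers, num_directions, batch, hidden)
# and keep only the last layer: (num_directions, batch, hidden) = (2, 4, 5).
last_layer = hn.reshape(num_layers, 2, batch, hidden)[-1]
```

The output switches between (5, 4, 10) and (4, 5, 10), while h_n stays (4, 4, 5) for both settings.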
This inconsistency is tracked in the following issue.
