PyTorch LSTM Input Confusion

Hi, I have a biLSTM with batch_first set to True. According to the docs, the input to the LSTM should have shape (N, L, H_in). I have realized my code has been working even though I have been feeding it an input of the wrong shape, and I'm not sure why it isn't throwing an error. Could someone explain? Also, now that I have realized this, I'm unsure how to get my [4 x 768] tensor into the shape [4 x 512 x 768]. Here's an example:

text_emb = embedding.encode(input, show_progress_bar=False, convert_to_tensor=True)
print(text_emb.shape)  # torch.Size([4, 768])
lstm = nn.LSTM(input_size=768, hidden_size=20, batch_first=True, bidirectional=True)
output, (ht, ct) = lstm(text_emb)
print(output.shape, ht.shape)

The text_emb.shape is [4, 768], which is the input to the LSTM, and this still works without throwing any errors?

1 Like

Have a look at the source code; this is the important snippet:

is_batched = input.dim() == 3
batch_dim = 0 if self.batch_first else 1
if not is_batched:
    input = input.unsqueeze(batch_dim)

Since your text_emb is only 2-dimensional, the LSTM treats it as a single sequence of length 4 rather than a batch of sequences. In that case the LSTM simply calls unsqueeze() to add the missing batch dimension. Since you use batch_first=True, the input becomes (1, 4, 768) after the unsqueeze().
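For anyone following along, here is a minimal, self-contained sketch (random data standing in for text_emb) showing what the two code paths return; the printed shapes follow the documented unbatched/batched conventions:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=768, hidden_size=20, batch_first=True, bidirectional=True)

unbatched = torch.randn(4, 768)            # (seq_len, input_size), no batch dim
out_u, (h_u, c_u) = lstm(unbatched)
print(out_u.shape, h_u.shape)              # torch.Size([4, 40]) torch.Size([2, 20])

batched = unbatched.unsqueeze(0)           # (1, 4, 768), explicit batch dim
out_b, (h_b, c_b) = lstm(batched)
print(out_b.shape, h_b.shape)              # torch.Size([1, 4, 40]) torch.Size([2, 1, 20])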

I would probably prefer that the LSTM throw an error and "force" the user to ensure the correct input shape :).

3 Likes

YES, exactly! I understand the feature request and that allowing inputs with a missing batch dimension might be convenient for some use cases, but I was already tricked by it a few times while trying to debug shape mismatch issues reported by users. :sweat_smile:

1 Like

I didn't even know that nn.LSTM does this (now), but this post made me curious and I thought this could be the only explanation. That's why I looked at the code to double-check.

Not sure if avoiding a single line of code like x.unsqueeze(0) is worth the potential pitfalls here.

1 Like

Yes, I totally agree, but I have to admit I missed the actual feature request on GitHub, where I could have raised some concerns.

1 Like

@msabrii As a side note: The docs are actually up to date. The specification for the input distinguishes between batched and unbatched input. It seems the addition was made in version 1.11.

Hi, thank you for your help. Does this mean that I am potentially running an older version of PyTorch and that is why I am not getting an error? I have been using Google Colab, so that could be the case.

As a side note, if someone could explain how I can convert my tensor to the required shape for input to this LSTM, that would be very helpful. I pass a batch of strings to the embedding layer, which returns a tensor of shape [4, 768]. I know the max sequence length for my inputs is 512, but I'm unsure how to convert the tensor returned by the embedding layer into that shape.

Thank you all again for your help! :slight_smile:

Edit: I should mention that the reason the embedding layer returns a tensor of that shape is that I am using a sentence embedding, so it's one tensor for the whole sentence. That is why I am confused about how to reshape my input to match the LSTM requirement.

No, you are not running an older PyTorch release, as newer versions allow “unbatched” inputs.

You could unsqueeze the tensor and work with a sequence length of 1, but I’m unsure why you would like to use an RNN if your use case deals with samples which do not have a temporal dimension.
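Purely for illustration (random data standing in for the sentence embeddings), that option would look like this: unsqueezing a sequence-length dimension of 1 turns the [4, 768] batch into something the batch-first LSTM accepts.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=768, hidden_size=20, batch_first=True, bidirectional=True)
text_emb = torch.randn(4, 768)     # stand-in for the 4 sentence embeddings
as_seq = text_emb.unsqueeze(1)     # (4, 1, 768): batch of 4 sequences of length 1
output, (ht, ct) = lstm(as_seq)
print(output.shape)                # torch.Size([4, 1, 40])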

If you pass a batch of strings, do you mean a sequence of tokens/words?

Usually the input for the embedding layer is already (batch_size, seq_len). Once pushed through the embedding layer, the output would be (batch_size, seq_len, embed_size) where embed_size has to match the input_size of the LSTM.
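As a minimal sketch of that usual token-level pipeline (the vocabulary size and token ids here are made up for illustration):

import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=10000, embedding_dim=768)
lstm = nn.LSTM(input_size=768, hidden_size=20, batch_first=True, bidirectional=True)

token_ids = torch.randint(0, 10000, (4, 512))   # (batch_size, seq_len)
emb = embed(token_ids)                           # (4, 512, 768) = (batch, seq_len, embed_size)
output, (ht, ct) = lstm(emb)
print(output.shape)                              # torch.Size([4, 512, 40])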

Where is your max sequence length of 512 reflected in the shape (4, 768)?

Ahh I see. For some context, the only reason I found out I had the wrong input was that I got an error for it when I ran the code on a different machine, so I assumed that machine had a newer version of PyTorch, hence the error for the same code that ran fine for me on Colab.

Your comment also made me realize I have approached my task wrong. I am trying to classify Reddit comments, and what I was doing was creating a single tensor for one whole post, when what I need to do is split the post into sentences, get the sentence embeddings, and build a tensor of those per-sentence embeddings so the LSTM can learn the dependencies between them. Thank you for your help!
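In case it helps future readers, here is a hedged sketch of that corrected approach. It assumes `embedding` is a sentence-transformers model (which the encode() arguments in the first snippet suggest), and the example posts and sentence splitting are purely hypothetical:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from sentence_transformers import SentenceTransformer  # assumption about what `embedding` is

embedding = SentenceTransformer("all-mpnet-base-v2")    # any 768-dim sentence-embedding model
lstm = nn.LSTM(input_size=768, hidden_size=20, batch_first=True, bidirectional=True)

# One list of sentences per post (hypothetical data for illustration).
posts = [
    ["First sentence of post one.", "Second sentence."],
    ["Post two has a single sentence."],
]

# One (num_sentences, 768) tensor per post.
per_post = [embedding.encode(sents, convert_to_tensor=True) for sents in posts]

# Pad to a common number of sentences: (batch_size, max_num_sentences, 768).
batch = pad_sequence(per_post, batch_first=True)
output, (ht, ct) = lstm(batch)   # output: (batch_size, max_num_sentences, 40)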

1 Like