Interpreting LSTM output for 2D input

I am having trouble understanding how nn.LSTM deals with a 2D input. For the code below:

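import torch
import torch.nn as nn
from torch.autograd import Variable

lstm = nn.LSTM(10, 5)  # input_size=10, hidden_size=5
v1 = Variable(torch.randn(13, 10))
v2 = Variable(torch.randn(4, 13, 10))
o1 = lstm(v1)
o2 = lstm(v2)
print(o1[0].size(), o2[0].size())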

Output is:

torch.Size([13, 10, 5]) torch.Size([4, 13, 5])

I understand what’s happening in the case of v2. However, v1 is a 2D tensor, yet its output is 3D. What is nn.LSTM doing with v1? How do I interpret the output?

Check the docs http://pytorch.org/docs/0.3.1/nn.html#lstm to see the expected format of the tensors.
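For reference, here is a minimal sketch of the layout those docs describe, assuming the default batch_first=False: the input is (seq_len, batch, input_size), and a single sequence such as v1 needs an explicit batch dimension of size 1. It uses plain tensors instead of Variable, which assumes PyTorch 0.4 or later, where the two were merged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=5)  # default batch_first=False

# Batched input laid out as (seq_len, batch, input_size).
v2 = torch.randn(4, 13, 10)                   # 4 time steps, batch of 13
out2, (h_n, c_n) = lstm(v2)
print(out2.shape)   # torch.Size([4, 13, 5]): one hidden vector per step and sequence
print(h_n.shape)    # torch.Size([1, 13, 5]): (num_layers * num_directions, batch, hidden_size)

# A single sequence of 13 steps: add a batch dimension of size 1 at dim 1.
v1 = torch.randn(13, 10)
out1, _ = lstm(v1.unsqueeze(1))               # input (13, 1, 10)
print(out1.shape)   # torch.Size([13, 1, 5])

# With batch_first=True the batch dimension comes first instead.
lstm_bf = nn.LSTM(input_size=10, hidden_size=5, batch_first=True)
out_bf, _ = lstm_bf(v1.unsqueeze(0))          # input (1, 13, 10)
print(out_bf.shape)  # torch.Size([1, 13, 5])
```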

That is (obviously) where I looked first. It doesn’t mention anything about how it deals with tensors of dimension two. I expected an LSTM to treat a 2D tensor much like a PackedSequence, but it doesn’t.
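For comparison, this is what a packed input actually looks like. A minimal sketch, assuming PyTorch 0.4 or later (where torch.nn.utils.rnn.pack_sequence is available): the packed data tensor really is 2D, (total time steps, input_size), but it travels together with batch_sizes, which tells the LSTM how many sequences are still active at each step. A bare 2D tensor carries no such information.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=5)

# Three sequences of different lengths, each (length, input_size), sorted longest first.
seqs = [torch.randn(13, 10), torch.randn(8, 10), torch.randn(4, 10)]
packed = pack_sequence(seqs)

print(packed.data.shape)   # torch.Size([25, 10]): 13 + 8 + 4 steps stacked into one 2D tensor
print(packed.batch_sizes)  # how many sequences are still active at each time step

packed_out, _ = lstm(packed)                   # the output is again a PackedSequence
out, lengths = pad_packed_sequence(packed_out)
print(out.shape)   # torch.Size([13, 3, 5]): padded back to (max_len, batch, hidden_size)
print(lengths)     # original length of each sequence
```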

The documentation also (implicitly) says that the input should have three dimensions: input (seq_len, batch, input_size). But it doesn’t fail or warn when dealing with a 2D input.

Also, what confused me is that an input with 2 dimensions has an output with 3 dimensions.
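For readers on newer releases: as far as I know, since PyTorch 1.11 nn.LSTM officially accepts a 2D input and treats it as a single unbatched sequence of shape (seq_len, input_size), so the output is 2D as well. A minimal sketch, assuming such a release; the last line compares it with the explicit batch-of-one path.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(10, 5)

v1 = torch.randn(13, 10)                 # (seq_len, input_size), no batch dimension

out_unbatched, (h_n, c_n) = lstm(v1)     # treated as one sequence of 13 steps
print(out_unbatched.shape)               # torch.Size([13, 5])
print(h_n.shape)                         # torch.Size([1, 5]): no batch dimension either

# Same computation with an explicit batch of one, then dropping that dimension.
out_batched, _ = lstm(v1.unsqueeze(1))   # (13, 1, 10) -> (13, 1, 5)
print(torch.allclose(out_unbatched, out_batched.squeeze(1)))  # True
```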

I tried your code and PyTorch 0.3.1 complained that the LSTM input didn’t have 3 dimensions.

It looks like the LSTM took your input of shape (13, 10), unsqueezed it to (13, 10, 1), and then broadcast it to (13, 10, 10). That means it interpreted the 2D input as 10 sequences of length 13 with 10 features at each time step, BUT those 10 features all held identical values. This makes no sense, so the output makes no sense either.


Well, telling you to look at the docs was not meant to be mean (maybe too blunt; I apologize).

I agree with you that it should complain when the input is not 3-dimensional!
In the code (torch/nn/modules/rnn.py), the batch size is taken from input.size(1) because the input is not packed. But when check_forward_args validates the input size, it is given that batch size, so it assumes the input is packed and is then happy with a 2D input!
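The check the docs imply could be enforced by deciding the expected rank from the input's type rather than from whether a batch size happens to be known. This is only a sketch: check_lstm_input is a hypothetical helper, not part of PyTorch, and it encodes the 0.3-era rule (recent releases additionally accept a 2D unbatched input).

```python
import torch
from torch.nn.utils.rnn import PackedSequence, pack_sequence


def check_lstm_input(x, input_size):
    """Hypothetical helper, not a PyTorch API: validate an nn.LSTM input up front.

    Packed data must be (total_steps, input_size); a plain tensor must be
    (seq_len, batch, input_size).
    """
    if isinstance(x, PackedSequence):
        data, expected_dim = x.data, 2
    else:
        data, expected_dim = x, 3
    if data.dim() != expected_dim:
        raise ValueError(
            f"expected a {expected_dim}D input, got {data.dim()}D with shape {tuple(data.shape)}"
        )
    if data.size(-1) != input_size:
        raise ValueError(f"expected input_size={input_size}, got {data.size(-1)}")


check_lstm_input(torch.randn(4, 13, 10), input_size=10)                # passes
check_lstm_input(pack_sequence([torch.randn(13, 10)]), input_size=10)  # passes: packed data is 2D
try:
    check_lstm_input(torch.randn(13, 10), input_size=10)               # the v1 case from the question
except ValueError as err:
    print(err)  # expected a 3D input, got 2D with shape (13, 10)
```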

I did not check further, but I guess broadcasting makes the forward pass work, because the size of your dim 1 (10) happens to match the expected number of features.

EDIT: it does not complain in version ‘0.3.0.post4’ (time to update…)
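To see how broadcasting could turn the 2D input into the 3D output reported in the question, here is a hand-written LSTM loop. It is a sketch under two assumptions: the cell follows the gate equations in the nn.LSTM docs (this is not PyTorch 0.3.0's actual implementation), and the batch size is read from input.size(1), as described for torch/nn/modules/rnn.py. Under those assumptions each 1D time step is shared by all 10 "batch" entries, so the 10 output sequences are identical copies.

```python
import torch

torch.manual_seed(0)

# Hand-written LSTM loop using the gate equations from the nn.LSTM docs.
# Assumption: this mimics, but is not, the code path PyTorch 0.3.0 actually ran.
input_size, hidden_size = 10, 5
W_ih = torch.randn(4 * hidden_size, input_size)   # gate order: i, f, g, o
W_hh = torch.randn(4 * hidden_size, hidden_size)
b_ih = torch.randn(4 * hidden_size)
b_hh = torch.randn(4 * hidden_size)

v1 = torch.randn(13, input_size)     # the 2D input from the question
batch = v1.size(1)                   # batch size read from dim 1, which is 10 here
h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)

outputs = []
for x_t in v1:                       # each time step is 1D: shape (10,)
    # (20,) + (10, 20) broadcasts to (10, 20): every batch row sees the same x_t
    gates = x_t @ W_ih.T + b_ih + h @ W_hh.T + b_hh
    i, f, g, o = gates.chunk(4, dim=1)
    c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h = torch.sigmoid(o) * torch.tanh(c)
    outputs.append(h)

out = torch.stack(outputs)
print(out.shape)  # torch.Size([13, 10, 5]), the shape reported in the question
# All 10 "sequences" are copies of one another, so the extra dimension carries no information.
print(torch.allclose(out, out[:, :1].expand_as(out)))  # True
```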

No worries.

I have the same version, ‘0.3.0.post4’, which I will update now.