Please help: LSTM input/output dimensions

I am hopelessly lost trying to understand the shape of data coming in and out of an LSTM.

Most attempts to explain the data flow involve using randomly generated data with no real meaning, which is incredibly unhelpful.

Those examples that use real data, like this Udacity notebook on the topic, do not explain it well and do not generalize the concept to other kinds/shapes of data beyond strings of text.

This post asks the same question I have. No attempt at an answer was ever made, except that of the OP, who was unsure if his conclusion was correct.

So, what I’m asking is this: Can someone provide a simple, concrete example of using an LSTM with:

  1. “real” data with one dimension (one variable).
  2. “real” data with more than one dimension (multiple variables).

I would greatly appreciate it!

I’ve created a minimal example to feed two sentences into an LSTM. I hope that helps. Let me know if anything is unclear.


@vdw,

Thank you for putting in the time to create those examples. Unfortunately, the use of the “embedding” layer has made things even more confusing for me. I truly am lost, and each additional example I see, including that one, adds even more confusion.

I feel that part of the problem is that my questions are too general to be understood. This is my fault. Therefore, I’m making a final attempt: Below is some code from my attempt to work alongside this tutorial.

I have taken a page from your book and added lots of print statements for the outputs of various layers. Below the outputs I have listed specific questions that I have.

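import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler

class flightLSTM(nn.Module):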
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden = None
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc   = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
       
        print(f'input original shape: {x.shape}')
        
        # format input to shape (batch_size, seq_len, input_size)
        x = x.view(1, len(x), 1) 
        
        print(f'input reshaped for LSTM: {x.shape} \n')
        
        out, self.hidden = self.lstm(x, self.hidden)
        
        print(f'output shape from LSTM layer: {out.shape}')
        print(f'output = \n {out}')
        print()
        
        # flatten output to (batch_size, hidden_size)
        out = out.view(-1, self.hidden_size)
       
        print(f'out reshaped for FC layer: {out.shape} \n')
        out = self.fc(out)
        
        print(f'Fully connected output shape: {out.shape}')
        print(f'output = \n {out}')
        
        return out

input_size = 1
hidden_size = 3
num_layers = 1
output_size = 1

model = flightLSTM(input_size, hidden_size, num_layers, output_size)

# loads a Pandas dataframe
data = sns.load_dataset("flights")
data.passengers = data.passengers.astype(np.float32)

# reserve the last 12 months as test data. 
train_data = data.passengers[:-12].to_numpy()
test_data  = data.passengers[-12:].to_numpy()

# scale and convert training data to tensor
scaler = MinMaxScaler(feature_range=(-1, 1))
scaled_train_data = scaler.fit_transform(train_data.reshape(-1, 1))
scaled_train_data = torch.tensor(scaled_train_data, dtype=torch.float32)

# get target data from known sequence
def get_targets(data, window):
    sequences = []
    L = len(data)
    for i in range(L - window):
        sequence = data[i:i + window]
        label = data[i + window: i + window + 1]
        sequences.append((sequence, label))
        
    return sequences

seq_len  = 12
seq_data = get_targets(scaled_train_data, seq_len)

Okay, here are the results of a forward pass through the network:


model(seq_data[0][0])

>>>
input original shape: torch.Size([12, 1])
input reshaped for LSTM: torch.Size([1, 12, 1]) 

output shape from LSTM layer: torch.Size([1, 12, 3])
output = 
 tensor([[[ 0.1085,  0.0436, -0.0182],
         [ 0.1245,  0.0547, -0.0299],
         [ 0.1277,  0.0552, -0.0445],
         [ 0.1289,  0.0555, -0.0494],
         [ 0.1298,  0.0570, -0.0462],
         [ 0.1298,  0.0550, -0.0541],
         [ 0.1292,  0.0518, -0.0669],
         [ 0.1295,  0.0504, -0.0733],
         [ 0.1304,  0.0520, -0.0682],
         [ 0.1311,  0.0558, -0.0542],
         [ 0.1313,  0.0602, -0.0369],
         [ 0.1306,  0.0595, -0.0377]]], grad_fn=<TransposeBackward0>)

out reshaped for FC layer: torch.Size([12, 3]) 

Fully connected output shape: torch.Size([12, 1])
output = 
 tensor([[0.2949],
        [0.2927],
        [0.2865],
        [0.2844],
        [0.2862],
        [0.2822],
        [0.2758],
        [0.2726],
        [0.2753],
        [0.2824],
        [0.2911],
        [0.2906]], grad_fn=<AddmmBackward>)

Okay, so my questions about this specific code:

  1. I’m providing one sequence of 12 data points. That means my batch size is one, correct?
  2. If I wanted to provide the network two sequences at a time, I’d need to shape my input to (2, 12, 1). Does this mean a batch is defined as a set of column vectors, where each sequence is a column?
  3. If I wanted to have batches with two sequences, each one containing two variables per point in time, I’d have to change my shape to (2, 12, 2). Is that correct?
  4. Going back to the data output from the forward pass above. The LSTM layer outputs a tensor of size (1, 12, 3). It’s basically a tensor containing 3 columns. What does each column represent? What does each element of a given column represent?
  5. I need to “flatten” the output tensor before I can feed it to my fully connected layer. Have I done so correctly?
  6. The output of the fully connected layer is essentially a column vector. I specified that my output be of size 1, but I got back 12. I think this means that each element of my 12 point sequence gets run through the layers individually. That makes me think that each element of this output is simply the predicted “next step” for that point in the sequence, but I don’t know. What does each element of this column vector mean?
  7. If I am correct about #6, why does the author of this code only have a singleton label for each 12 point sequence? See the function “get_targets”.

OK, so you’re doing basic time series analysis. In this case, input_size=1, of course.

Before replying to your specific questions, just some general comments:

  • Try not to think too much in terms of row or column vectors. Once you have tensors with more than 2 dimensions (i.e., more than a matrix), the notions of rows, columns, etc. quickly get confusing. Think in terms of dimensions and their order.
  • Network layers only throw an error if the shape of the input is not correct. However, just because the input shape is correct, doesn’t mean that the input itself is correct. Please have a look at a previous post of mine where I outline some pitfalls of view() and reshape(). For example, the line x = x.view(1, len(x), 1) should be fine since you only have 1 sequence in your batch and input_size=1.
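To illustrate the view() pitfall with a minimal sketch (toy values, not the flights data): if a batch happens to be stored time-major as (seq_len, batch, input_size), view() to (batch, seq_len, input_size) silently mixes values from different sequences, while transpose() (or permute()) actually reorders the dimensions.

```python
import torch

# Hypothetical toy batch: 2 sequences, 3 time steps, 1 feature each,
# stored time-major as (seq_len, batch, input_size)
seq_a = torch.tensor([1.0, 2.0, 3.0])
seq_b = torch.tensor([10.0, 20.0, 30.0])
time_major = torch.stack([seq_a, seq_b], dim=1).unsqueeze(-1)  # (3, 2, 1)

# Wrong: view() only reinterprets the memory layout, so the "sequences"
# it produces contain values from both original sequences
wrong = time_major.view(2, 3, 1)

# Right: swap the time and batch dimensions -> (batch, seq_len, input_size)
right = time_major.transpose(0, 1)  # (2, 3, 1)

print(wrong.squeeze(-1))  # rows mix 1, 10, 2, ... from both sequences
print(right.squeeze(-1))  # rows are [1, 2, 3] and [10, 20, 30]
```

Both results have the correct shape (2, 3, 1), so the LSTM would accept either one without complaint; only the second one is the intended batch.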
  1. I’m providing one sequence of 12 data points. That means my batch size is one, correct?

Yes! It’s only a bit unintuitive why the shape of the original x is (12, 1), since the batch size usually comes first. But I assume that’s just the result of your get_targets() method.

  2. If I wanted to provide the network two sequences at a time, I’d need to shape my input to (2, 12, 1). Does this mean a batch is defined as a set of column vectors, where each sequence is a column?

Yes! This needs to be the resulting shape. But again, be careful how you get to this shape without messing up your batch tensor; see link above.

  3. If I wanted to have batches with two sequences, each one containing two variables per point in time, I’d have to change my shape to (2, 12, 2). Is that correct?

Yes! (same comment as above). I would only change the phrasing a bit: the batch contains 2 sequences, each sequence contains 12 data points, and each data point has 2 features (or, is represented by a 2-dim (feature) vector). The term “variable” does fit here.
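A minimal sketch of building both shapes with torch.stack, assuming two hypothetical sequences (placeholder values, not real passenger counts) and a made-up second variable per time step:

```python
import torch

seq_len = 12
# Hypothetical data: two univariate sequences of 12 time steps each
seq_1 = torch.arange(seq_len, dtype=torch.float32)
seq_2 = torch.arange(seq_len, dtype=torch.float32) + 100

# One feature per step: stack along a new batch dim, then add the feature dim
batch_uni = torch.stack([seq_1, seq_2]).unsqueeze(-1)  # (2, 12, 1)

# Two features per step: pair each sequence with a hypothetical second variable
extra_1 = torch.zeros(seq_len)
extra_2 = torch.ones(seq_len)
batch_multi = torch.stack([
    torch.stack([seq_1, extra_1], dim=-1),  # (12, 2): one row per time step
    torch.stack([seq_2, extra_2], dim=-1),  # (12, 2)
])  # (2, 12, 2)

# batch_multi[b, t, f] = feature f of time step t in sequence b
print(batch_uni.shape)    # torch.Size([2, 12, 1])
print(batch_multi.shape)  # torch.Size([2, 12, 2])
print(batch_multi[1, 3])  # second sequence, fourth time step: values 103 and 1
```

Stacking explicitly along the intended dimension avoids the reshaping pitfalls mentioned above, because no dimension is ever reinterpreted.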

  4. Going back to the data output from the forward pass above. The LSTM layer outputs a tensor of size (1, 12, 3). It’s basically a tensor containing 3 columns. What does each column represent? What does each element of a given column represent?

Don’t think in terms of rows or columns. The output of an LSTM gives you the hidden states for each data point in a sequence, for all sequences in a batch. You only have 1 sequence, it comes with 12 data points, and each data point has 3 features (since hidden_size=3 is the size of the LSTM layer). Maybe this image helps a bit:

In your case, you have h_1 to h_12 since you have 12 data points, and each h_i is a vector with 3 features.
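To check the "h_1 to h_12" picture in code, here is a small sketch with a randomly initialized LSTM of the same sizes as in the question (input_size=1, hidden_size=3); the input values are random placeholders, since only the shapes matter here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Same sizes as in the question: input_size=1, hidden_size=3, one layer
lstm = nn.LSTM(input_size=1, hidden_size=3, num_layers=1, batch_first=True)
x = torch.randn(1, 12, 1)  # hypothetical (batch_size, seq_len, input_size)

out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([1, 12, 3]): h_1 ... h_12, each a 3-dim vector
print(h_n.shape)  # torch.Size([1, 1, 3]): (num_layers * num_directions, batch, hidden_size)

# For a single-layer, unidirectional LSTM, the last time step of `out`
# is the same vector as the final hidden state h_n
print(torch.allclose(out[:, -1, :], h_n[-1]))  # True
```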

  5. I need to “flatten” the output tensor before I can feed it to my fully connected layer. Have I done so correctly?

No! The linear layer expects an input shape of (batch_size, "something"). Since your batch size is 1, out after flattening needs to be of shape (1, "something"), but you have (12, "something"). Note that self.fc doesn’t care; it just sees a batch of size 12 and processes it. In your simple case, a quick fix would be out = out.view(1, -1), which gives a shape of (1, 12 * 3) = (1, 36); self.fc then also has to be defined as nn.Linear(seq_len * hidden_size, output_size) to accept those 36 features.
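The two flattening options side by side, as a sketch using a random stand-in for the LSTM output (sizes taken from the question: batch_size=1, seq_len=12, hidden_size=3). Note that after out.view(1, -1), the original nn.Linear(hidden_size, output_size) would raise a shape error, so the linear layer needs seq_len * hidden_size input features.

```python
import torch
import torch.nn as nn

batch_size, seq_len, hidden_size, output_size = 1, 12, 3, 1
out = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for the LSTM output

# Original flattening: every time step becomes its own "sample"
per_step = out.view(-1, hidden_size)
print(per_step.shape)  # torch.Size([12, 3]) -> an artificial batch of 12

# Keeping the real batch dimension: all time steps of a sequence in one row
per_sequence = out.view(batch_size, -1)
print(per_sequence.shape)  # torch.Size([1, 36])

# The linear layer then has to accept seq_len * hidden_size features
fc = nn.Linear(seq_len * hidden_size, output_size)
print(fc(per_sequence).shape)  # torch.Size([1, 1]): one prediction per sequence
```

Both forms are legal as far as nn.Linear is concerned; the difference is what a row means: one time step versus one whole sequence.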

  6. The output of the fully connected layer is essentially a column vector. I specified that my output be of size 1, but I got back 12. I think this means that each element of my 12 point sequence gets run through the layers individually. That makes me think that each element of this output is simply the predicted “next step” for that point in the sequence, but I don’t know. What does each element of this column vector mean?

Once you have fixed the flattening, self.fc will return an output of size 1. Through the incorrect flattening, you created a batch of size 12.

  7. If I am correct about #6, why does the author of this code only have a singleton label for each 12 point sequence? See the function “get_targets”.

Same issue: your flattening is off.


Thank you for your help. I think I’m starting to get it. Sorry if I’m difficult to get through to!

Yes, the time-series data made more sense to me as a beginner than the text data because I could think in terms of ‘samples’ and ‘variables’ :slight_smile: I’ll try tackling text data next.

Follow up 1: Is there any reason I couldn’t have input_size > 1? What if my time series data has more than just a passengers variable? For instance each point in time might also have an average weather, or some such thing. Can I not use multiple variables to predict a single time series outcome?

follow up 2: Okay, so LSTM layer outputs/hidden aren’t worth trying to interpret? All of that data is transient and unit-less, not meant for human consumption? At least I understand the shape of it now!

follow up 3: I’m wondering why .view(-1, self.hidden_size) was used here, in this Udacity notebook in their .forward() method?

I thought that .view(-1, hidden_size) would give me exactly (batch_size, hidden_size), and it appears to have worked. Perhaps it was a lucky accident.

I didn’t understand the stackabuse approach, which was to do out.view(len(input), -1).

Yes, the time-series data made more sense to me as a beginner than the text data because I could think in terms of ‘samples’ and ‘variables’ :slight_smile: I’ll try tackling text data next.

A sentence is nothing but a time series of words. The LSTM doesn’t care :slight_smile:

Follow up 1: Is there any reason I couldn’t have input_size > 1 ? What if my time series data has more than just a passengers variable? For instance each point in time might also have an average weather, or some such thing. Can I not use multiple variables to predict a single time series outcome?

You can of course have multiple features for each data point, i.e., input_size > 1.
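A minimal many-to-one sketch with input_size=2, assuming a hypothetical second variable (for instance an average-weather value) alongside the scaled passenger count. Both features go into the last dimension, and the model still predicts a single value per window. The class name and sizes are illustrative only.

```python
import torch
import torch.nn as nn

class MultiFeatureLSTM(nn.Module):
    """Hypothetical many-to-one model: several features in, one value out."""

    def __init__(self, input_size, hidden_size, output_size=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch_size, seq_len, input_size)
        _, (h_n, _) = self.lstm(x)
        # h_n[-1]: final hidden state of the last layer, (batch_size, hidden_size)
        return self.fc(h_n[-1])

torch.manual_seed(0)
# Hypothetical batch: 4 windows of 12 months, each month with 2 features
x = torch.randn(4, 12, 2)
model = MultiFeatureLSTM(input_size=2, hidden_size=8)
print(model(x).shape)  # torch.Size([4, 1]): one predicted value per window
```

Only input_size changes when more variables are added; the hidden size and output size are independent of it.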

follow up 2: Okay, so LSTM layer outputs/hidden aren’t worth trying to interpret? All of that data is transient and unit-less, not meant for human consumption? At least I understand the shape of it now!

Exactly.

follow up 3: I’m wondering why .view(-1, self.hidden_size) was used here, in this Udacity notebook in their .forward() method?

I’ve looked at the notebook, and I would argue the setup is a bit different. If I understand your code correctly, your inputs are sequences of length 12 and the target is of length 1 (basically your one label). In the notebook, the target is also a sequence of the same length as the input sequence, only shifted by one time step. That means with one example, say input=[1,2,3,4] and target=[2,3,4,5], they simulate a batch of 4 examples:

  • [1] => 2
  • [1,2] => 3
  • [1,2,3] => 4
  • [1,2,3,4] => 5

So there .view(-1, self.hidden_size) makes sense.

Can you post the shape of your labels, as well as the training code, including the part where you calculate the loss?


You are correct, the label shape is a 1-d array/tensor.

Here is the training code:

import torch.nn as nn
import torch.optim as optim

def train(model, epochs, train_set, lr=0.001, print_every=10):
    
    criterion = nn.MSELoss()
    opt = optim.Adam(model.parameters(), lr=lr)
    
    for e in range(epochs):
        hidden = None
        for x, y in train_set:
            
            opt.zero_grad()
            
            out, hidden = model(x, hidden)
        
            loss = criterion(out, y.squeeze(0))
            loss.backward()
            opt.step()
            
            # detach the hidden state so gradients don't flow back into earlier windows
            hidden = tuple(h.detach() for h in hidden)
            
        if e % print_every == 0:
            print(f'Epoch {e} Training Loss: {loss.item():.4f}')

And here is my full notebook. At this point, I think I’ve piecemeal copied and pasted the whole thing into this thread. Maybe it will be helpful to have it all in one place.

Thank you again for your ongoing help! I hope to finally understand this stuff and stop bugging you :sweat_smile:

I think I found the reason why it works despite the wrong flattening. The last two lines in your forward() method are:

def forward(self, x, hs):
    ...
    out = self.fc(out)
    return out[-1], hs

The out[-1] resolves the “artificial” batch of 12 values, so you have only 1 output value. If you had done the flattening correctly, the [-1] wouldn’t be needed. The problem is that you stuck too closely to the Udacity tutorial although your setup is different, so you had to make tweaks like the out[-1] to fix it.

It’s a bit like a math exam where you make a mistake at the beginning and a mistake at the end, and both mistakes cancel each other out.
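A quick way to confirm that the two mistakes cancel out, sketched with random weights and a random input sequence: flattening to a fake batch of 12 and then keeping the last row gives the same value as feeding only the last time step to the linear layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 3
lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, 1)
x = torch.randn(1, 12, 1)  # one hypothetical sequence of 12 steps

out, _ = lstm(x)  # (1, 12, 3)

# The two "mistakes": flatten into a fake batch of 12, then keep only the last row
two_mistakes = fc(out.view(-1, hidden_size))[-1]  # shape (1,)

# The direct route: take the last time step of the (only) sequence
direct = fc(out[:, -1, :])  # shape (1, 1)

print(torch.allclose(two_mistakes, direct.squeeze(0)))  # True
```

This only holds because the batch contains a single sequence. With two sequences, the last row of the flattened tensor would belong to the second sequence only, and the prediction for the first sequence would be lost.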


Okay, lesson learned. Shape of input for FC layer is always (batch_size, “something”), which is just like it was for CNNs. Except calculating that “something” was easy back then. It was just the num_channels * height * width.

Exactly! Here it’s a bit trickier since you have the time step dimension.

In fact, with your setup I would not have used the LSTM output but the last hidden state:

def forward(self, x, hs):
    
    # format input to shape (batch_size, seq_len, input_size) 
    x = x.view(1, len(x), 1) 
    
    out, hs = self.lstm(x, hs)
    
    (h, c) = hs # h.shape = (num_layers * num_directions, batch, hidden_size) = (1, 1, 100)

    last_h = h[-1] # Here the [-1] is not because of the sequence but because of the num_layers
    # last_h.shape = (batch, hidden_size) = (1, 100) <== What you want

    # flatten LSTM output for fully connected layer: (batch_size, hidden_size)
    #out = out.view(-1, self.hidden_size) # No longer needed
    
    out = self.fc(last_h) # out.shape = (batch_size, output_size) = (1, 1)
    
    #return out[-1], hs
    return out, hs

You might want to give this a try.
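As a sketch of how this last-hidden-state approach scales beyond one window at a time: stack all windows into a batch of shape (num_windows, 12, 1) with labels of shape (num_windows, 1), so predictions and labels line up for nn.MSELoss. The sine series is a hypothetical stand-in for scaled_train_data, and unlike the training loop in this thread, every window starts from a zero hidden state (a design choice, not a requirement).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
window, hidden_size = 12, 16

# Hypothetical scaled series of shape (N, 1), standing in for scaled_train_data
series = torch.sin(torch.linspace(0, 12, 132)).unsqueeze(-1)

# Same windows as get_targets(): input = 12 values, label = the value after them
n = len(series) - window
inputs = torch.stack([series[i:i + window] for i in range(n)])  # (120, 12, 1)
labels = torch.stack([series[i + window] for i in range(n)])    # (120, 1)

lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, 1)
opt = torch.optim.Adam(list(lstm.parameters()) + list(fc.parameters()), lr=0.01)
criterion = nn.MSELoss()

losses = []
for epoch in range(50):
    opt.zero_grad()
    _, (h_n, _) = lstm(inputs)  # no hidden state passed: zeros for every window
    preds = fc(h_n[-1])         # (120, 1), same shape as labels
    loss = criterion(preds, labels)
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(preds.shape, labels.shape)  # torch.Size([120, 1]) torch.Size([120, 1])
```

Because each window is treated as an independent sample, there is no hidden state to detach between iterations.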


So this solution is a way of getting the “many to one” output, correct?

And this works because the final element of the final hidden state is equal to the final element of the output tensor? That diagram you sent me was helpful.

That would give the correct shape of (batch_size, hidden_size) = (1, 100). I think I get it!

Okay, but…

What if I wanted to do a sequence input to sequence output? More like in the NLP Udacity example.

In this case, my LSTM output would have shape (1, 12, 100) if batch_first=True.

out, hs = self.lstm(x, hs)
out = out.view(1, -1) 

# gives shape (batch_size, "something") = (1, 1200) 

So that would not work, because the fully connected layer needs input shaped like (batch_size, hidden_size).

Instead, if I do this:

out, hs = self.lstm(x, hs)
out = out.reshape(-1, self.hidden_size)

# gives shape (12, 100), which works with FC layer.

But would this be correct, seeing as how we are artificially making a new batch_size = 12, just for the final FC layer?

Or, should I redefine my nn.Linear() layer:

self.fc = nn.Linear(hidden_size * seq_length, output_size)?

Where output_size = 12 if I want a sequence output.


So this solution is a way of getting the “many to one” output, correct?

Precisely! That is the basic approach that fits your setup, where each example is a sequence of 12 data points as input and one target value. Yes, that diagram made things clearer to me as well :).

What if I wanted to do a sequence input to sequence output?

Sure, you can also copy the Udacity example basically 1:1; you only need to change your data. For this setup, both inputs and targets are sequences of length 12, with the target sequences shifted by one data point. In this case,

reshape(-1, self.hidden_size)

with a resulting output shape of (12, hidden_size) is the correct way to go, since you need/want 12 output values, one for each element in your target sequence.
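A sketch of this sequence-to-sequence variant, with a random 13-value series as a hypothetical stand-in for real data: inputs are steps 1 to 12, targets are steps 2 to 13, and after reshape(-1, hidden_size) the linear layer produces one prediction per time step, matching the (12, 1) target tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, hidden_size = 12, 16

# Hypothetical series of 13 scaled values: inputs are steps 1..12,
# targets are the same sequence shifted by one step (2..13)
series = torch.randn(seq_len + 1)
x = series[:-1].view(1, seq_len, 1)  # (batch_size, seq_len, input_size)
y = series[1:].view(-1, 1)           # (seq_len, 1): one target per time step

lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, 1)

out, _ = lstm(x)                    # (1, 12, 16)
out = out.reshape(-1, hidden_size)  # (12, 16): one row per time step
preds = fc(out)                     # (12, 1): prediction k is for step k + 1
loss = nn.MSELoss()(preds, y)
print(preds.shape, y.shape)  # torch.Size([12, 1]) torch.Size([12, 1])
```

With more than one sequence per batch, the (batch, seq_len, hidden_size) output is flattened sequence by sequence, so the targets would have to be flattened the same way, e.g. targets of shape (batch, seq_len, 1) reshaped to (-1, 1). reshape() is the safer choice there, since the batch_first output may not be contiguous.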

Or, should I redefine my nn.Linear() layer: self.fc = nn.Linear(hidden_size * seq_length, output_size)? where output_size = 12 if I want a sequence output.

Hm, that’s a good question. Technically you can set up the model like this; I’m just not sure if this is “semantically” the right way to do it. Something feels a bit off, but I cannot put my finger on it right now :). I would actually be quite curious how the results would compare to the more established ways.


Great!

What is the functional difference between “many-to-one” and “one-to-one” approaches?

I can see that you would need to “prime” a “many-to-one” model with some past data to get the prediction at the next point, whereas with a “one-to-one” model, you only need a single past data point. I can’t tell when one or the other would be more beneficial.

What the Udacity tutorial is doing is not a “one-to-one” but more like a “many-to-many” approach. Note that the value of the k-th output depends on all inputs up to and including the k-th. In other words, the k-th output depends on the hidden state at time k, which in turn depends on the k-th input and the (k-1) hidden states before it.

In fact, you can directly map the approach used in the Udacity tutorial to the basic “many-to-one” approach. Instead of having one (input => target) example such as

  • [1, 2, 3, 4, 5, 6] => [2, 3, 4, 5, 6, 7]

you can convert this into 6 “many-to-one” examples

  • [1] => [2]
  • [1, 2] => [3]
  • [1, 2, 3] => [4]
  • [1, 2, 3, 4] => [5]
  • [1, 2, 3, 4, 5] => [6]
  • [1, 2, 3, 4, 5, 6] => [7]

From a learning/training perspective, that’s exactly the same. That’s what the reshape(-1, self.hidden_size) in the Udacity tutorial is doing: “internally” creating this artificial batch of, in this case, size 6.
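The equivalence can be checked numerically in a small sketch with random weights: the per-step predictions from one pass over [1, ..., 6] match the predictions obtained by running each prefix as its own many-to-one example, assuming each prefix starts from a zero hidden state (the default when no hidden state is passed).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 4
lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, 1)

seq = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).view(1, -1, 1)  # (1, 6, 1)

# Many-to-many: one pass over the whole sequence, one prediction per step
out, _ = lstm(seq)
many_to_many = fc(out.reshape(-1, hidden_size)).squeeze(-1)  # 6 predictions

# Many-to-one: run each prefix [1], [1, 2], ..., [1, ..., 6] separately
many_to_one = []
for k in range(1, seq.size(1) + 1):
    _, (h_n, _) = lstm(seq[:, :k])  # prefix of length k
    many_to_one.append(fc(h_n[-1]).squeeze())
many_to_one = torch.stack(many_to_one)

print(torch.allclose(many_to_many, many_to_one, atol=1e-6))  # True
```

The single pass is simply cheaper: it computes all six prefix predictions while walking through the sequence once.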


I know now why that approach does not seem suitable to me. Here, all output values depend on all input values. That also means that the first output value would depend on the last input value. This means the network can look into the future.

You can still implement it, of course, but it doesn’t seem to fit your prediction task.
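One way to make the "looking into the future" concern concrete, sketched with random weights and a random input: compute the gradient of the first prediction with respect to the last input. With a shared per-step linear layer it is exactly zero. With a single nn.Linear(hidden_size * seq_length, seq_length) over the flattened output (the output_size=12 variant asked about earlier) it is not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, hidden_size = 12, 8
lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)

# Option A: one shared linear layer applied to every time step
fc_per_step = nn.Linear(hidden_size, 1)
# Option B: one linear layer over all time steps flattened together
fc_flat = nn.Linear(hidden_size * seq_len, seq_len)

x = torch.randn(1, seq_len, 1, requires_grad=True)
out, _ = lstm(x)  # (1, 12, 8)

pred_a = fc_per_step(out.reshape(-1, hidden_size)).view(-1)  # 12 predictions
pred_b = fc_flat(out.reshape(1, -1)).view(-1)                # 12 predictions

def last_input_grad(prediction):
    """Gradient of one scalar prediction w.r.t. the final input step."""
    (grad,) = torch.autograd.grad(prediction, x, retain_graph=True)
    return grad[0, -1, 0].item()

grad_a = last_input_grad(pred_a[0])
grad_b = last_input_grad(pred_b[0])
print(grad_a == 0.0)  # True: the first prediction ignores step 12
print(grad_b != 0.0)  # True: the first prediction "sees" step 12
```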
