Encoder/Decoder LSTM model for time series forecasting

I’m trying to implement an encoder-decoder LSTM model for a univariate time-series forecasting problem with multivariate covariates. In other words, I have a predictor time-series variable y and associated time-series features that should help predict future values of y. The structure of the encoder-decoder network, as I understand and have implemented it, is shown in the figure (apologies for the formatting of the key, I couldn’t get the last entry to format on one line correctly!).

Below is a description of a toy example where I want to predict y two steps into the future using the past three timepoints. The general concept is that the encoder LSTM encodes a context vector which can then be used to generate the prediction sequence step by step. However, I’m getting some tensor dimension mismatches which I don’t understand (all dims are included in the diagram and I am using a batch-first approach). For simplicity, the example assumes a single LSTM layer:

Encoder: I pass an input tensor consisting of the predictor variable and covariates at times t-2, t-1 and t, which has dimensions (N, L, H_in) where L=3 and H_in is the number of covariate features + 1 (in the diagram I have unrolled the LSTM for each time input, which is why L=1 there). The output of the encoder is the hidden state (h_t, c_t), each of which has dimensions (1, N, H_out), where H_out is the encoder LSTM hidden size.
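
Just to make the shapes concrete, here is a minimal sketch of the encoder side (the sizes are only illustrative, with H_in = 5 covariates + 1 = 6 and H_out = 100 to match the toy example below):

import torch
import torch.nn as nn

N, L, H_in, H_out = 10, 3, 6, 100                  # batch size, lookback, covariates + 1, hidden size
encoder = nn.LSTM(H_in, H_out, num_layers=1, batch_first=True)
x = torch.rand(N, L, H_in)                         # y and covariates at t-2, t-1, t
output, (h_t, c_t) = encoder(x)
print(output.shape, h_t.shape, c_t.shape)          # (10, 3, 100), (1, 10, 100), (1, 10, 100)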

Decoder: The final hidden state of the encoder is passed as the initial hidden state of the decoder, along with the current value of y at time t, to predict y_t+1. The predicted value of y_t+1 and the hidden state at t+1 are then used to predict y_t+2.

Theoretically I think this should work; however, I am getting a dimension mismatch between the hidden state output from the encoder and the input of the decoder. The encoder hidden state has shape (1, N, H_out), where H_out is the hidden size of the encoder LSTM. The input of the decoder LSTM is (N, 1, 1), as it is just the last known value of the predictor variable. My understanding is that the last dimension of the hidden state should match the last dimension of the decoder input, but this can’t happen unless the LSTM has a hidden size of 1 (and PyTorch is giving a dimension mismatch error). This will also be an issue when using the decoder to generate the output predictions, since the hidden size of the decoder won’t be 1.

Do I have a fundamental misunderstanding of how encoder/decoder networks are used for time series forecasting, or is there a step that I am missing? I have read that encoder/decoder networks can be used for time series forecasting, but I can’t understand how they get around this issue! A lot of the examples are for NLP and use embedding layers, which I think sidesteps the problem.

Can you provide a condensed code example that gives the error? Just feed a torch.rand() of appropriate size as input.

I am getting a dimension mismatch between the hidden state output from the Encoder and input of the Decoder.

The hidden state for each LSTM layer should be carried across consecutive timesteps and fed back into the same LSTM it came out of.

Sure, here is the most minimal example I can give. I have broken the code into the classes for the encoder/decoder model and a minimal toy example that attempts to train the model on artificial data. Yes, the hidden states are fed back into the encoder or decoder LSTM they came from. The only place where this is not the case is where the final hidden state of the encoder is used as the initial hidden state of the decoder (which is how I understood an encoder/decoder network to work). Also, I know the code below will not currently work for a bidirectional LSTM, even though it is possible to request one when instantiating the classes.

First the model classes:

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EncoderDecoderLSTM(nn.Module):
    def __init__(self,
                 encoder_model,
                 decoder_model,
                 lr=0.001,
                 n_epochs=100,):
        super(EncoderDecoderLSTM, self).__init__()

        self.lr = lr
        self.n_epochs = n_epochs

        self.encoder_model = encoder_model
        self.decoder_model = decoder_model

        self.encoder_optimizer = optim.Adam(self.encoder_model.parameters(), lr=lr)
        self.decoder_optimizer = optim.Adam(self.decoder_model.parameters(), lr=lr)

        self.criterion = nn.MSELoss()


    def train_epoch(self,
                    dataloader):

        total_loss = 0
        for data in dataloader:
            input_tensor, target_tensor = data

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            # May need to initialise hidden?
            encoder_hidden = self.encoder_model.init_hidden(batch_size=input_tensor.shape[0])
            _, encoder_hidden = self.encoder_model.forward(input_tensor, encoder_hidden)
            decoder_outputs, decoder_hidden, _ = self.decoder_model.forward(
                decoder_input=input_tensor[:, -1, -1].unsqueeze(1).unsqueeze(1),
                encoder_hidden=encoder_hidden)

            loss = self.criterion(decoder_outputs, target_tensor)
            loss.backward()
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()

            total_loss += loss.item()

        return total_loss / len(dataloader)

    def train(self,
              dataloader):

        for epoch in range(1, self.n_epochs + 1):
            loss = self.train_epoch(dataloader)


class EncoderLSTM(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_dim,
                 batch_size,
                 n_layers,
                 bidirectional=False,
                 dropout_p=0.1):
        super(EncoderLSTM, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_dim
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout_p = dropout_p
        self.n_layers = n_layers

        self.lstm = nn.LSTM(input_size,
                            hidden_dim,
                            n_layers,
                            batch_first=True,
                            bidirectional=bidirectional,
                            dropout=dropout_p)

    def forward(self, X, hidden):
        print(f'Encoder Before:, input: {X.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        output, hidden = self.lstm(X, hidden)
        print(f'Encoder Final:, output: {output.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        return output, hidden

    def init_hidden(self, batch_size):

        hidden = (torch.zeros(self.n_layers, batch_size, self.hidden_size),
                  torch.zeros(self.n_layers, batch_size, self.hidden_size))
        return hidden


class DecoderLSTM(nn.Module):
    def __init__(self, hidden_dim,
                 output_size,
                 batch_size,
                 n_layers,
                 forecasting_horizon,
                 bidirectional=False,
                 dropout_p=0,):
        super(DecoderLSTM, self).__init__()

        self.hidden_size = hidden_dim
        self.output_size = output_size
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout_p = dropout_p
        self.forecasting_horizon = forecasting_horizon

        self.lstm = nn.LSTM(hidden_dim,
                            hidden_dim,
                            n_layers,
                            batch_first=True,
                            bidirectional=bidirectional,
                            dropout=dropout_p)

        self.out = nn.Linear(hidden_dim, output_size)

    def forward(self,
                decoder_input,
                encoder_hidden):

        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(self.forecasting_horizon):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)
            decoder_input = decoder_hidden[0][-1, :, :].unsqueeze(0).permute(1, 0, 2)

        decoder_outputs = torch.cat(decoder_outputs, dim=1)

        return decoder_outputs, decoder_hidden, None  # We return `None` for consistency in the training loop

    def forward_step(self, X, hidden):
        print(f'Decoder Before:, input: {X.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        output, hidden = self.lstm(X, hidden)
        print(f'Decoder After:, output: {output.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        output = self.out(output)
        print(f'Decoder Final:, output: {output.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        return output, hidden


class SequenceDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.len = x.shape[0]

    def __getitem__(self, idx):
        x = self.x[idx, :, :]
        y = self.y[idx, :]
        return x, y

    def __len__(self):
        return self.len

This is the minimal example to attempt to train the model using linear combinations of sine waves:

t = np.arange(0, 100).reshape((-1, 1))
x1 = np.sin(t)
x2 = 2*np.sin(0.8 * t)
x3 = np.sin(1.5 * t)
x4 = 0.1 * np.sin(0.1 * t)
x5 = np.sin(2 * t)

y = x1 + x2 + x3 + x4 + x5

data = np.hstack((x1, x2, x3, x4, x5, y))

# Chop the data into X and y data using a sliding window
forecast_horizon = 2 #No. of time steps into the future to predict
chunk_size = 3 #No. of previous time steps to use
max_counter = data.shape[0] - chunk_size - forecast_horizon
data_chunked_inputs = np.zeros((max_counter, chunk_size, data.shape[1]))
data_chunked_outputs = np.zeros((max_counter, forecast_horizon, 1))
counter = 0
start_time_input = 0
batch_size = 10

while counter < max_counter:
    # Specify the start and end times for each chunk
    end_time_input = start_time_input + chunk_size
    start_time_output = end_time_input
    end_time_output = start_time_output + forecast_horizon
    # slice chunks
    input_chunk = data[start_time_input:end_time_input, :]
    output_chunk = data[start_time_output:end_time_output, -1].reshape((-1, 1))
    data_chunked_inputs[counter, :, :] = input_chunk
    data_chunked_outputs[counter, :, :] = output_chunk
    start_time_input += 1
    counter += 1

# Create the dataloader
dataset = SequenceDataset(x=data_chunked_inputs, y=data_chunked_outputs)
loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

# Set some arbitrary parameters
hidden_dim = 100
n_rnn_layers = 1
dropout = 0
learning_rate = 0.001
n_epochs = 100

# Initialise each of the models
model_encoder = EncoderLSTM(input_size=data.shape[1],
                            hidden_dim=hidden_dim,
                            batch_size=batch_size,
                            n_layers=n_rnn_layers,
                            bidirectional=False,
                            dropout_p=dropout)

model_decoder = DecoderLSTM(hidden_dim=hidden_dim,
                            output_size=1,
                            batch_size=batch_size,
                            n_layers=n_rnn_layers,
                            forecasting_horizon=forecast_horizon,
                            bidirectional=False,
                            dropout_p=0)

model_encoder_decoder = EncoderDecoderLSTM(encoder_model=model_encoder,
                                           decoder_model=model_decoder,
                                           lr=learning_rate)

# Attempt to train the model
model_encoder_decoder.train(dataloader=loader)

The error occurs in the LSTM call inside the decoder’s forward_step method, on the first iteration it is used:


RuntimeError: input.size(-1) must be equal to input_size. Expected 100, got 1
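
For reference, the mismatch can be reproduced in isolation with just the decoder LSTM and random tensors of the sizes used above:

import torch
import torch.nn as nn

lstm = nn.LSTM(100, 100, 1, batch_first=True)     # decoder LSTM: input_size = hidden_dim = 100
h0 = torch.zeros(1, 10, 100)                      # (n_layers, N, hidden_size), as returned by the encoder
c0 = torch.zeros(1, 10, 100)
x = torch.rand(10, 1, 1)                          # last known value of y: (N, 1, 1)
lstm(x, (h0, c0))                                 # raises the RuntimeError above: expected 100, got 1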

So I think I figured this out. I just added a linear layer in front of the decoder LSTM, which lets me transform the input to the correct dimensions.
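
In case it helps anyone else, this is roughly what that change looks like (a sketch only; the class and layer names here are just what I called them, not from the code above):

import torch.nn as nn

class ProjectedDecoderLSTM(nn.Module):
    # Sketch: a linear layer in front of the decoder LSTM lifts the scalar input up to hidden_dim
    def __init__(self, hidden_dim, output_size, n_layers=1):
        super().__init__()
        self.input_projection = nn.Linear(1, hidden_dim)    # (N, 1, 1) -> (N, 1, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, n_layers, batch_first=True)
        self.out = nn.Linear(hidden_dim, output_size)

    def forward_step(self, X, hidden):
        X = self.input_projection(X)             # match the LSTM's expected input size
        output, hidden = self.lstm(X, hidden)
        output = self.out(output)                # (N, 1, hidden_dim) -> (N, 1, output_size)
        return output, hidden

With the projection in place, the decoder input and the encoder hidden state both have hidden_dim as their last dimension, so the LSTM call goes through.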

The first thing I notice is that you’re missing a forward function in the main module, EncoderDecoderLSTM. Perhaps you want that to just be a trainer class and not an nn.Module child. Or you can add a forward function, something like:

    def forward(self, input_tensor):
        encoder_hidden = self.encoder_model.init_hidden(batch_size=input_tensor.shape[0])
        _, encoder_hidden = self.encoder_model(input_tensor, encoder_hidden)
        decoder_outputs, decoder_hidden, _ = self.decoder_model(
            decoder_input=input_tensor[:, -1, -1].unsqueeze(1).unsqueeze(1),
            encoder_hidden=encoder_hidden)
        return decoder_outputs

Second, I noticed that you’re passing the encoder’s hidden state to the decoder. That is likely where you are getting a size mismatch. Instead, the decoder should instantiate its own hidden state. Consider these points:

  1. If your decoder layer is supposed to learn how to remember and filter previous states, how can it accomplish this if it keeps getting the encoder’s state?
  2. If the encoder output and decoder output sizes are different, and you are passing the encoder hidden state to the decoder, you will most certainly get a size mismatch. That is because the hidden size is based on the output size, not the input size, as the quick shape check below illustrates.
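
A quick shape check of point 2 (arbitrary sizes, just to illustrate):

import torch
import torch.nn as nn

enc = nn.LSTM(input_size=6, hidden_size=100, batch_first=True)
dec = nn.LSTM(input_size=1, hidden_size=50, batch_first=True)    # decoder with a different hidden size
_, enc_hidden = enc(torch.rand(10, 3, 6))                        # h and c are both (1, 10, 100)
dec(torch.rand(10, 1, 1), enc_hidden)                            # RuntimeError about hidden[0] size: got (1, 10, 100), expected (1, 10, 50)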

Thanks for the help, I appreciate it.

In regard to your second point, I thought that was the whole point of how a seq2seq model works? Below is an excerpt from an article on the encoder-decoder architecture:

Encoder - The encoder reads the entire input sequence, one word per timestep, processes it and captures some contextual information about the input sequence into what is known as a context vector or thought vector. This is expected to contain a good summary of the entire input sentence. Every cell in the LSTM layer returns a hidden state (h_i) and cell state (c_i). The last hidden state and cell state are used to initialize the decoder, which is the second component of this architecture.

Decoder - Just like the encoder, the decoder reads the entire target sequence offset by one timestep, along with the last hidden state and cell state of the encoder, and predicts the next word in the target sequence. We add a start-of-sequence token to the beginning of the target sequence, indicating the start of a sentence, and an end-of-sequence token to the end, indicating the end of the sentence.

From the way I understood it, the hidden state of the encoder LSTM summarises the input sequence, and the decoder network then uses this hidden state to create a meaningful output sequence. Am I misunderstanding?

In the toy example I gave, both the encoder and decoder have the same output size, which is hidden_size. The decoder output is then passed through a linear layer to produce the final output prediction.

Sorry for the delay. For some reason, I did not see the notification of your reply.

Below is a simplified example that removes the error. Note that I passed the output of the encoder layer to the decoder layer. Alternatively, you could pass the entire input into the decoder layer, though I’m not sure what you’d hope to accomplish with that.

I also condensed the forward pass, which is then called as self(input_tensor) in the train_epoch function. Note that I changed def train(self, dataloader): to def trainer(self, dataloader):. This is because an nn.Module already has a train method, which is how you switch the module between train and eval mode.
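
As a quick illustration of that name collision, this is the built-in behaviour being shadowed:

import torch.nn as nn

model = nn.Linear(1, 1)
model.train()            # built-in nn.Module.train(mode=True): puts the module in training mode
model.eval()             # shorthand for model.train(False)
print(model.training)    # False

And here is the full example: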

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EncoderDecoderLSTM(nn.Module):
    def __init__(self,
                 encoder_model,
                 decoder_model,
                 lr=0.001,
                 n_epochs=100,):
        super(EncoderDecoderLSTM, self).__init__()

        self.lr = lr
        self.n_epochs = n_epochs

        self.encoder_model = encoder_model
        self.decoder_model = decoder_model

        self.encoder_optimizer = optim.Adam(self.encoder_model.parameters(), lr=lr)
        self.decoder_optimizer = optim.Adam(self.decoder_model.parameters(), lr=lr)

        self.criterion = nn.MSELoss()


    def train_epoch(self,
                    dataloader):

        total_loss = 0
        for data in dataloader:
            input_tensor, target_tensor = data

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            decoder_outputs = self(input_tensor)
            loss = self.criterion(decoder_outputs, target_tensor)
            loss.backward()
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()

            total_loss += loss.item()

        return total_loss / len(dataloader)

    def forward(self, input_tensor):
        encoder_hidden = self.encoder_model.init_hidden(batch_size=input_tensor.shape[0])

        encoder_output, encoder_hidden = self.encoder_model(input_tensor, encoder_hidden)
        decoder_outputs, decoder_hidden, _ = self.decoder_model(
            decoder_input=encoder_output,
            encoder_hidden=encoder_hidden)
        return decoder_outputs

    def trainer(self, dataloader):

        for epoch in range(1, self.n_epochs + 1):
            loss = self.train_epoch(dataloader)


class EncoderLSTM(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_dim,
                 batch_size,
                 n_layers,
                 bidirectional=False,
                 dropout_p=0.1):
        super(EncoderLSTM, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_dim
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout_p = dropout_p
        self.n_layers = n_layers

        self.lstm = nn.LSTM(input_size,
                            hidden_dim,
                            n_layers,
                            batch_first=True,
                            bidirectional=bidirectional,
                            dropout=dropout_p)

    def forward(self, X, hidden):
        print(f'Encoder Before:, input: {X.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        output, hidden = self.lstm(X, hidden)
        print(f'Encoder Final:, output: {output.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        return output, hidden

    def init_hidden(self, batch_size):

        hidden = (torch.zeros(self.n_layers, batch_size, self.hidden_size),
                  torch.zeros(self.n_layers, batch_size, self.hidden_size))
        return hidden


class DecoderLSTM(nn.Module):
    def __init__(self, hidden_dim,
                 output_size,
                 batch_size,
                 n_layers,
                 forecasting_horizon,
                 bidirectional=False,
                 dropout_p=0,):
        super(DecoderLSTM, self).__init__()

        self.hidden_size = hidden_dim
        self.output_size = output_size
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout_p = dropout_p
        self.forecasting_horizon = forecasting_horizon

        self.lstm = nn.LSTM(hidden_dim,
                            hidden_dim,
                            n_layers,
                            batch_first=True,
                            bidirectional=bidirectional,
                            dropout=dropout_p)

        self.out = nn.Linear(hidden_dim, output_size)

    def forward(self,
                decoder_input,
                encoder_hidden):

        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(self.forecasting_horizon):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)
            decoder_input = decoder_hidden[0][-1, :, :].unsqueeze(0).permute(1, 0, 2)

        decoder_outputs = torch.cat(decoder_outputs, dim=1)

        return decoder_outputs, decoder_hidden, None  # We return `None` for consistency in the training loop

    def forward_step(self, X, hidden):
        print(f'Decoder Before:, input: {X.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        output, hidden = self.lstm(X, hidden)
        print(f'Decoder After:, output: {output.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        output = self.out(output)
        print(f'Decoder Final:, output: {output.shape}, h: {hidden[0].shape}, c: {hidden[1].shape}')
        return output, hidden


class SequenceDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.len = x.shape[0]

    def __getitem__(self, idx):
        x = self.x[idx, :, :]
        y = self.y[idx, :]
        return x, y

    def __len__(self):
        return self.len



# Set some arbitrary parameters
hidden_dim = 100
n_rnn_layers = 1
dropout = 0
learning_rate = 0.001
n_epochs = 100
batch_size = 10
forecast_horizon = 2

# Initialise each of the models
model_encoder = EncoderLSTM(input_size=hidden_dim,
                            hidden_dim=hidden_dim,
                            batch_size=batch_size,
                            n_layers=n_rnn_layers,
                            bidirectional=False,
                            dropout_p=dropout)

model_decoder = DecoderLSTM(hidden_dim=hidden_dim,
                            output_size=1,
                            batch_size=batch_size,
                            n_layers=n_rnn_layers,
                            forecasting_horizon=forecast_horizon,
                            bidirectional=False,
                            dropout_p=0)

model_encoder_decoder = EncoderDecoderLSTM(encoder_model=model_encoder,
                                           decoder_model=model_decoder,
                                           lr=learning_rate)


dummy_inputs = torch.rand((batch_size, 6, 100))  # (N, L, input_size); input_size was set to hidden_dim above
print(model_encoder_decoder(dummy_inputs))