Predicting future values with LSTM

I’m currently working on building an LSTM model to forecast time-series data using PyTorch. I used lag features to pass the previous n steps as inputs to train the network. I split the data into three sets (train, validation, and test) and used the first two to train the model. My validation function takes the data from the validation set and calculates the predicted values by passing it to the LSTM model using the DataLoader and TensorDataset classes. Initially, I got pretty good results, with R2 values in the region of 0.85-0.95.

However, I have an uneasy feeling about whether this validation function is also suitable for testing my model’s performance, because it takes the actual X values, i.e., the time-lag features, from the DataLoader to predict y^ values, i.e., the predicted target values, instead of feeding the predicted y^ values back in as features for the next prediction. This seems far from reality, where the model has no clue of the real values of the previous time steps, especially if you forecast time-series data over longer periods, say 3-6 months.

I’m currently a bit puzzled about how to tackle this issue and define a function to predict future values relying on the model’s values rather than the actual values in the test set. I have the following function predict, which makes a one-step prediction, but I haven’t really figured out how to predict the whole test dataset using DataLoader.

    def predict(self, x):
        # move the input to the same device as the model
        x = x.to(device)
        # make a one-step prediction
        yhat = self.model(x)
        # move the result to the CPU before converting it to a NumPy array
        yhat = yhat.detach().cpu().numpy()
        return yhat

You can find how I split and load my datasets, my constructor for the LSTM model, and the validation function below. If you need more information, please do not hesitate to reach out to me.

Splitting and Loading Datasets

def create_tensor_datasets(X_train_arr, X_val_arr, X_test_arr, y_train_arr, y_val_arr, y_test_arr):
    train_features = torch.Tensor(X_train_arr)
    train_targets = torch.Tensor(y_train_arr)
    val_features = torch.Tensor(X_val_arr)
    val_targets = torch.Tensor(y_val_arr)
    test_features = torch.Tensor(X_test_arr)
    test_targets = torch.Tensor(y_test_arr)

    train = TensorDataset(train_features, train_targets)
    val = TensorDataset(val_features, val_targets)
    test = TensorDataset(test_features, test_targets)

    return train, val, test

def load_tensor_datasets(train, val, test, batch_size=64, shuffle=False, drop_last=True):
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    return train_loader, val_loader, test_loader
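For context, the arrays passed to create_tensor_datasets hold time-lag features. A minimal NumPy sketch of how such features can be built is below; `make_lag_features` is a hypothetical helper for illustration, and the column order (most recent lag first) is an assumption rather than a description of my actual preprocessing.

```python
import numpy as np

def make_lag_features(series, n_lags):
    """Build X with n_lags columns (most recent lag first) and the aligned target y."""
    series = np.asarray(series, dtype=np.float32)
    # Column k-1 holds the value observed k steps before the target
    X = np.stack(
        [series[n_lags - k : len(series) - k] for k in range(1, n_lags + 1)],
        axis=1,
    )
    # The target for each row is the value right after its lag window
    y = series[n_lags:].reshape(-1, 1)
    return X, y

# Toy series: the first row becomes lags [12, 11, 10] with target 13
X, y = make_lag_features([10, 11, 12, 13, 14, 15], n_lags=3)
```

The resulting X and y would then be split chronologically (no shuffling) before being wrapped in TensorDataset objects.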

Class LSTM

class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, dropout_prob):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, layer_dim, batch_first=True, dropout=dropout_prob
        )

        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, future=False):
        # initialize hidden and cell states with zeros on the same device as the input
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)

        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = out[:, -1, :]
        out = self.fc(out)

        return out

Validation (defined within a trainer class)

    def validation(self, val_loader, batch_size, n_features):

        with torch.no_grad():
            predictions = []
            values = []
            for x_val, y_val in val_loader:
                x_val = x_val.view([batch_size, -1, n_features]).to(device)
                y_val = y_val.to(device)
                self.model.eval()
                yhat = self.model(x_val)
                predictions.append(yhat.cpu().detach().numpy())
                values.append(y_val.cpu().detach().numpy())

        return predictions, values
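Since validation returns lists of per-batch arrays, they have to be stitched back together before a score such as R2 can be computed. A minimal sketch, assuming a single target column; `r2_from_batches` is a hypothetical helper name and not part of my trainer class:

```python
import numpy as np

def r2_from_batches(predictions, values):
    """Concatenate per-batch arrays and compute the coefficient of determination."""
    y_pred = np.concatenate(predictions, axis=0).ravel()
    y_true = np.concatenate(values, axis=0).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot

# Toy data standing in for the lists returned by validation()
values = [np.array([[1.0], [2.0]]), np.array([[3.0], [4.0]])]
predictions = [np.array([[1.0], [2.0]]), np.array([[3.0], [4.0]])]

perfect = r2_from_batches(predictions, values)                    # exact predictions
mean_only = r2_from_batches([np.full((2, 1), 2.5)] * 2, values)   # always predict the mean
```

Predicting the mean everywhere gives a score of zero, which is a useful baseline to keep in mind when reading R2 values for forecasts.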

I’ve finally found a way to forecast values based on predicted values from the earlier observations. As expected, the predictions were rather accurate in the short term and gradually became worse in the long term. It is not surprising that the predictions drift over time, as they no longer depend on the actual values. Reflecting on my results and the discussions I had on the topic, here are my takeaways:

  • In real-life cases, the real values can be retrieved and fed into the model at each step of the prediction (be it weekly, daily, or hourly), so that the next step can be predicted with the actual values from the previous step. So, testing the performance based on the actual values from the test set may somewhat reflect the real performance of a model that is maintained regularly.

  • However, for predicting future values in the long term, forecasting, if you will, you need to make either multiple one-step predictions or multi-step predictions that span over the time period you wish to forecast.

  • Making multiple one-step predictions based on the values predicted by the model yields plausible results in the short term. As the forecasting period increases, the errors accumulate, so the predictions become less accurate and therefore less fit for the purpose of forecasting.

  • To make multiple one-step predictions and update the input after each prediction, we have to work our way through the dataset one row at a time, in a for-loop over the test set, because each input depends on the previous prediction. Not surprisingly, this makes us lose the computational advantages that matrix operations and mini-batching provide.

  • An alternative could be predicting sequences of values instead of only the next value, say using RNNs with multi-dimensional output in a many-to-many or seq-to-seq structure. They are likely to be more difficult to train and less flexible when making predictions for different time periods. An encoder-decoder structure may prove useful for solving this, though I have not implemented it myself.
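To illustrate the last point, the simplest multi-step variant I can think of keeps the same LSTM but lets the final linear layer emit the whole horizon at once. This is only a sketch under my own assumptions, not something I have trained or benchmarked: `DirectMultiStepLSTM` and `horizon` are hypothetical names, and the shapes are toy values.

```python
import torch
import torch.nn as nn

class DirectMultiStepLSTM(nn.Module):
    """Hypothetical variant: predict `horizon` future values in one forward pass."""

    def __init__(self, input_dim, hidden_dim, layer_dim, horizon):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        # One output unit per future step instead of a single next-step value
        self.fc = nn.Linear(hidden_dim, horizon)

    def forward(self, x):
        # x: [batch, seq_len, input_dim]; initial states default to zeros
        out, _ = self.lstm(x)
        # Use the output at the last time step to predict the whole horizon
        return self.fc(out[:, -1, :])

torch.manual_seed(0)
model = DirectMultiStepLSTM(input_dim=1, hidden_dim=16, layer_dim=1, horizon=12)
x = torch.randn(8, 30, 1)   # 8 windows of 30 past observations (toy data)
y = torch.randn(8, 12)      # 12 future values per window (toy targets)

loss = nn.MSELoss()(model(x), y)  # same loss as the one-step model, wider targets
loss.backward()
out_shape = tuple(model(x).shape)
```

The horizon is fixed when the model is built, which is exactly the loss of flexibility mentioned above: forecasting a different period means retraining or truncating the output.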

Below is the code for my function, which forecasts the next n_steps starting from the last row of the dataset, namely X (time-lag features) and y (target value). To iterate over each row in my dataset, I set batch_size to 1 and n_features to the number of lagged observations.

    def forecast(self, X, y, batch_size=1, n_features=1, n_steps=100):
        predictions = []
        X = torch.roll(X, shifts=1, dims=2)
        X[..., -1, 0] = y.item(0)
        with torch.no_grad():
            self.model.eval()
            for _ in range(n_steps):
                X = X.view([batch_size, -1, n_features]).to(device)
                yhat = self.model(X)
                yhat = yhat.detach().cpu().numpy()
                X = torch.roll(X, shifts=1, dims=2)
                X[..., -1, 0] = yhat.item(0)
                predictions.append(yhat)

        return predictions

The following line shifts the values along the last dimension of the tensor (dims=2) by one, so that a tensor [[[x1, x2, x3, ... , xn]]] becomes [[[xn, x1, x2, ... , x(n-1)]]].

X = torch.roll(X, shifts=1, dims=2)

And the line below selects the first element along the last dimension of the 3D tensor and sets it to the predicted value stored in the NumPy ndarray (yhat), [[x(n+1)]]. The new input tensor then becomes [[[x(n+1), x1, x2, ... , x(n-1)]]].

X[..., -1, 0] = yhat.item(0)
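The two steps can be checked on a toy tensor. This is just a sketch with made-up numbers, where 5.0 stands in for the model's prediction:

```python
import torch

X = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # shape [1, 1, 4]; toy lag window
rolled = torch.roll(X, shifts=1, dims=2)    # last element wraps around to the front

updated = rolled.clone()
updated[..., -1, 0] = 5.0                   # overwrite the wrapped value with the new prediction

print(rolled.tolist())   # [[[4.0, 1.0, 2.0, 3.0]]]
print(updated.tolist())  # [[[5.0, 1.0, 2.0, 3.0]]]
```

Note that this only makes sense if index 0 holds the most recent lag, so the oldest value is the one that wraps around and gets overwritten.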

I tried to summarize some of the things I would have liked to know back when I started. I hope you’ll find it useful. Feel free to comment or reach out to me if you agree or disagree with any of the remarks I made above.


Thanks for the summary! It depends on the task, but seq2seq models for time series prediction do not give good results based on my experiments.