LSTM - Problem with multidimensional forecasting

Hey everyone,

I am currently working on an LSTM that should predict a multidimensional output for any given input sequence.
To clarify what I mean, I have the following setup:

Input: [batch_size, seq_len, n_features] = [32, 16, 2]


    def __init__(self, in_features: int, n_hidden: int, out_features: int, num_layers: int):
        super(LSTMModel, self).__init__()
        self.in_features = in_features
        self.n_hidden = n_hidden
        self.out_features = out_features
        self.num_layers = num_layers

        self.lstm1 = nn.LSTM(
        self.linear = nn.Linear(

    def forward(self, x: torch.FloatTensor) -> torch.Tensor:
        batch_size = x.size(0)

        h_t = torch.zeros(self.num_layers, batch_size, self.n_hidden, device=x.device).float()
        c_t = torch.zeros(self.num_layers, batch_size, self.n_hidden, device=x.device).float()

        x, (h_t, c_t) = self.lstm1(x, (h_t, c_t))
        x = self.linear(x[:, -1, :])
        return x

The Output should be:
[batch_size, target_seq_len, target_n_features] = [32, 4, 2]

The reason for this setup: I want to predict ‘target_seq_len’ future steps, where every target is a vector with ‘target_n_features’ entries, given a history (sequence) with the input shape described above.

So this is a multidimensional forecasting task. The problem is that the code I use only gives me an output of shape [batch_size, target_seq_len]. What can I do to solve this?
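
The shape mismatch described above can be reproduced with a small sketch. The constructor arguments of nn.LSTM and nn.Linear are truncated in the post, so the ones used here (a hidden size of 8, batch_first=True, and a linear layer that maps to target_seq_len values) are assumptions chosen only to match the reported shapes.

```python
import torch
import torch.nn as nn

# Sizes taken from the shapes quoted in the post.
batch_size, seq_len, n_features = 32, 16, 2
target_seq_len = 4
n_hidden = 8  # assumption: the real hidden size is not given in the post

# Assumption: batch_first=True, because the input is laid out as [batch, seq, features].
lstm = nn.LSTM(input_size=n_features, hidden_size=n_hidden, num_layers=1, batch_first=True)
# Assumption: the linear layer maps one hidden vector to target_seq_len values.
linear = nn.Linear(n_hidden, target_seq_len)

x = torch.randn(batch_size, seq_len, n_features)
out, _ = lstm(x)                 # [32, 16, 8]: one hidden vector per time step
y_pred = linear(out[:, -1, :])   # [32, 4]: indexing with -1 drops the time dimension
y_true = torch.randn(batch_size, target_seq_len, n_features)  # [32, 4, 2]

print(out.shape, y_pred.shape, y_true.shape)
```

With these assumptions y_pred has two dimensions while y_true has three, which matches the mismatch described in the post.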

I use standard PyTorch and pytorch-lightning for my problem.

Here is also a screenshot of the shapes (console). The first line is y_pred and the second is y_true.

[Screenshot 2022-07-05 164054]

Do you have any ideas for code changes or solutions?

~ Linus

You are explicitly using the last time step of the LSTM output only in this line of code:

    x = self.linear(x[:, -1, :])

and are thus removing the temporal dimension.
You can pass the entire tensor directly to the linear layer, which will then be applied to each time step.
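
A minimal sketch of this suggestion, under some assumptions: the class name LSTMForecaster, the hidden size, and batch_first=True are hypothetical, since the original constructor calls are truncated. Applying the linear layer to the full LSTM output yields one vector per input time step, so the result has shape [batch_size, seq_len, out_features]. The thread does not say how the 16 input steps were reduced to the 4 target steps; keeping the last target_seq_len steps, as done here, is only one possible choice.

```python
import torch
import torch.nn as nn


class LSTMForecaster(nn.Module):  # hypothetical name, not the original LSTMModel
    def __init__(self, in_features: int, n_hidden: int, out_features: int,
                 num_layers: int, target_seq_len: int):
        super().__init__()
        self.target_seq_len = target_seq_len
        # Assumption: batch_first=True to match [batch, seq, features] inputs.
        self.lstm = nn.LSTM(in_features, n_hidden, num_layers, batch_first=True)
        self.linear = nn.Linear(n_hidden, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(x)   # initial hidden and cell states default to zeros
        y = self.linear(out)    # applied per time step: [batch, seq_len, out_features]
        # Assumption: use the last target_seq_len steps as the forecast.
        return y[:, -self.target_seq_len:, :]


model = LSTMForecaster(in_features=2, n_hidden=8, out_features=2,
                       num_layers=1, target_seq_len=4)
x = torch.randn(32, 16, 2)
full = model.linear(model.lstm(x)[0])  # [32, 16, 2] before slicing
y_pred = model(x)                      # [32, 4, 2] after slicing
print(full.shape, y_pred.shape)
```

If the last-step-only design is preferred instead, another option would be a linear layer with target_seq_len * target_n_features outputs followed by a reshape to [batch_size, target_seq_len, target_n_features].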

Thank you for your answer. That solved my problem completely! :)