Making Predictions Using LSTM with Multivariate Input


I have built an LSTM for time-series prediction, but I am not sure how to actually make predictions into the future. The data I am using is multivariate, with shape (5219, 4), and each feature vector was converted into an input sequence using a moving-window method. After applying a window of length 5, the input becomes (5219, 5, 4), which matches the (batch, sequence length, features) layout the LSTM module expects with batch_first=True. After training, the plot of the model's output tracks the training data fairly accurately. However, I am not sure how to perform predictions.
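A minimal sketch of the moving-window step, assuming the raw series is a NumPy array of shape (n_steps, 4) and that windows overlap with a stride of 1. A window of length 5 over n rows yields n - 5 + 1 windows, so the raw length of 5223 below is a hypothetical value chosen only so the result matches the (5219, 5, 4) shape above. `make_windows` is a hypothetical helper and the data is random placeholder data.

```python
import numpy as np

def make_windows(data: np.ndarray, window: int) -> np.ndarray:
    """Stack overlapping windows of `window` consecutive rows (stride 1).

    data:   (n_steps, n_features)
    return: (n_steps - window + 1, window, n_features)
    """
    n_steps = data.shape[0]
    return np.stack([data[i:i + window] for i in range(n_steps - window + 1)])

# Hypothetical raw series: 5223 time steps, 4 features (random placeholder values).
raw = np.random.rand(5223, 4).astype(np.float32)
X = make_windows(raw, window=5)
print(X.shape)  # (5219, 5, 4)
```

Each window X[i] covers rows i to i + 4 of the raw series, so its natural training target is the value at row i + 5.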

From what I've read on forums, the idea is to take a fixed sequence of data (the entire training set, in my case) and forecast the next day with a simple forward pass through the network. However, instead of a single output value, the model returns one prediction for every sample in the batch. For example:

input.shape -> (5219,5,4)
output.shape -> (5219,1)

I expected that, given the entire batch, the model would predict the 5220th element, since the model definition specifies a single output (the number of outputs I defined). Maybe my understanding is incorrect. I have attached a snippet of my model below to help illustrate the problem.

import torch
import torch.nn as nn

class LSTM(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output, n_layers):
        super(LSTM, self).__init__()
        self.hidden = n_hidden
        self.num_layers = n_layers
        self.rnn = nn.LSTM(n_feature, n_hidden, n_layers, batch_first=True)
        self.relu = nn.ReLU()
        self.output = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        # x: (batch, seq_len, n_feature); each row is an independent sequence
        # zero initial hidden and cell states, on the same device/dtype as the input
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden, device=x.device, dtype=x.dtype)
        x, (hn, cn) = self.rnn(x, (h0, c0))  # x: (batch, seq_len, n_hidden)
        x = self.relu(x)
        # decode the hidden state of the last time step of each sequence
        x = x[:, -1, :]     # (batch, n_hidden)
        x = self.output(x)  # (batch, n_output)
        return x

model = LSTM(4, 10, 1, 2)
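The shapes in the question follow from how the forward pass treats the first dimension: with batch_first=True, each of the 5219 rows is an independent 5-step sequence, and the model produces one value per sequence. A minimal sketch with a simplified stand-in for the model above (same sizes, PyTorch's default zero initial state, random placeholder input, untrained weights) shows that row i of the output belongs to window i, and that passing only the most recent window gives a single next-step value:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the post's model: 4 features, 10 hidden units, 2 layers, 1 output.
lstm = nn.LSTM(input_size=4, hidden_size=10, num_layers=2, batch_first=True)
head = nn.Linear(10, 1)

def predict(x: torch.Tensor) -> torch.Tensor:
    out, _ = lstm(x)                          # (batch, seq_len, hidden); zero initial state by default
    return head(torch.relu(out[:, -1, :]))    # one value per window, from its last time step

x = torch.randn(5219, 5, 4)                   # placeholder for the windowed data

with torch.no_grad():
    all_preds = predict(x)                    # (5219, 1): row i is the prediction for window i
    next_pred = predict(x[-1:])               # (1, 1): prediction from the most recent window only
```

Slicing with `x[-1:]` rather than `x[-1]` keeps the batch dimension, so the input still has the (batch, seq_len, features) layout the LSTM expects.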

Essentially, given multivariate input, how can I make predictions into the future? Any help with this would be greatly appreciated.
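For more than one step ahead, one common approach is recursive forecasting: predict the next value, append it to the window, drop the oldest row, and repeat. The sketch below assumes, hypothetically, that the predicted quantity is feature column 0 (`TARGET_COL`) and, because the future values of the other three features are unknown, simply holds them at their last observed values; if those covariates are known in advance, they should be used instead. `LSTMRegressor` and `forecast` are illustrative names, and the model is untrained, so only the mechanics are meaningful.

```python
import torch
import torch.nn as nn

TARGET_COL = 0  # hypothetical: assumes the predicted quantity is feature column 0

class LSTMRegressor(nn.Module):
    """Compact stand-in for the post's model (4 features -> 1 output)."""
    def __init__(self, n_feature=4, n_hidden=10, n_output=1, n_layers=2):
        super().__init__()
        self.rnn = nn.LSTM(n_feature, n_hidden, n_layers, batch_first=True)
        self.output = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        out, _ = self.rnn(x)
        return self.output(torch.relu(out[:, -1, :]))

@torch.no_grad()
def forecast(model: nn.Module, last_window: torch.Tensor, horizon: int) -> torch.Tensor:
    """Recursive forecast: predict one step, append it, slide the window, repeat.

    last_window: (window, n_features), the most recent observed rows.
    Returns a (horizon,) tensor of predicted target values.
    """
    model.eval()
    window = last_window.clone()
    preds = []
    for _ in range(horizon):
        y = model(window.unsqueeze(0))[0, 0]  # scalar prediction for the next step
        preds.append(y)
        next_row = window[-1].clone()         # ASSUMPTION: other features held at last values
        next_row[TARGET_COL] = y              # feed the prediction back in as the target feature
        window = torch.cat([window[1:], next_row.unsqueeze(0)])  # drop oldest row, append new one
    return torch.stack(preds)

torch.manual_seed(0)
model = LSTMRegressor()
history = torch.randn(5219, 5, 4)   # placeholder for the windowed training data
print(forecast(model, history[-1], horizon=7).shape)  # torch.Size([7])
```

Errors compound with each recursive step. An alternative is to train the model to output several future steps directly by setting n_output to the forecast horizon, which requires building a target of that length for each window.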