Prediction of LSTM example

Hi everyone,

I am trying to code a very simple LSTM. Below is how I defined the main class:

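import torch as tc
import torch.nn as nn

class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        super(lstm_mdl, self).__init__()
        # x: number of input features, n_nrns: hidden size, nl: stacked layers, y: output size
        self.lstm = nn.LSTM(input_size = x, hidden_size = n_nrns, num_layers = nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)
    def forward(self, x):
        # x: (batch, seq_len, features) because batch_first=True
        y, _ = self.lstm(x)            # y: (batch, seq_len, n_nrns)
        x = self.linear(y[:, -1, :])   # keep only the last time step -> (batch, y)
        return x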
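
In forward, y[:, -1, :] keeps only the LSTM output at the last time step of each sequence. As a quick sanity check (a sketch; the layer sizes and random batch below are arbitrary placeholders, not values from this post), for a unidirectional nn.LSTM this slice equals the top layer's final hidden state h_n[-1], and the model returns one value per sequence:

```python
import torch as tc
import torch.nn as nn

class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        super().__init__()
        self.lstm = nn.LSTM(input_size=x, hidden_size=n_nrns, num_layers=nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)

    def forward(self, x):
        y, _ = self.lstm(x)              # y: (batch, seq_len, n_nrns)
        return self.linear(y[:, -1, :])  # last time step only -> (batch, y)

tc.manual_seed(0)
model = lstm_mdl(x=2, n_nrns=16, nl=2, y=1)   # placeholder sizes
batch = tc.randn(8, 2, 2)                      # (samples, lookback, features)

out, (h_n, c_n) = model.lstm(batch)
# h_n is (num_layers, batch, hidden) even with batch_first=True;
# for a unidirectional LSTM its last layer matches the last time step of `out`.
same = tc.allclose(out[:, -1, :], h_n[-1])
pred = model(batch)
print(same, tuple(pred.shape))  # True (8, 1)
```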

Now, assume three time series y, A and B: we want to predict y from A and B. Each series has 10 observations, and we want a lookback of one lag, i.e. each input window covers 2 consecutive periods (lookback = 2 in the code).
To do this, we run the following code on y and on A_B (A and B stacked as two feature columns) to get the 3D tensors we use for training and for computing the errors:

lookback = 2  # window length: current observation plus one lag

x = tc.tensor(A_B, dtype = tc.float32) #torch.Size([10, 2])
y = tc.tensor(y.reshape(-1, 1), dtype = tc.float32) #torch.Size([10, 1])

data_size = x.shape[0] - lookback  # 8 windows that have a target

x_train = tc.zeros(data_size, lookback, x.shape[1])
y_train = tc.zeros(data_size, 1)
for i in range(data_size):
    x_train[i] = x[i:(i + lookback)]  # rows i .. i+lookback-1; x_train ends as torch.Size([8, 2, 2])
    y_train[i] = y[i + lookback]      # next observation; y_train ends as torch.Size([8, 1])

This gives the usual 3D input tensor of shape (samples, lookback, features): sample i contains rows i to i + lookback - 1 of x, and its target is y[i + lookback]. In other words, we get the usual sequences of moving windows extracted from the original dataset.
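
The same windows can also be built without a Python loop. The sketch below is an alternative, not the original code: it assumes A, B and y are 1-D series of equal length (the toy values are placeholders), stacks A and B as two feature columns with tc.stack, and uses Tensor.unfold to take every run of lookback consecutive rows. The last window is dropped because it has no target. It then checks the result against the loop above:

```python
import torch as tc

lookback = 2
n_obs = 10

# Toy series standing in for the real A, B and y.
A = tc.arange(n_obs, dtype=tc.float32)
B = 100 + tc.arange(n_obs, dtype=tc.float32)
y = 0.5 * A + 0.1 * B

x = tc.stack([A, B], dim=1)          # (10, 2): one column per feature, not A + B
y = y.reshape(-1, 1)                 # (10, 1)
data_size = x.shape[0] - lookback    # 8

# unfold(0, lookback, 1) -> (9, 2, lookback); transpose to (9, lookback, 2),
# then keep the first data_size windows (the last one has no target).
x_win = x.unfold(0, lookback, 1).transpose(1, 2)[:data_size]
y_win = y[lookback:]                 # target for window i is y[i + lookback]

# Reference: the loop-based construction.
x_train = tc.zeros(data_size, lookback, x.shape[1])
y_train = tc.zeros(data_size, 1)
for i in range(data_size):
    x_train[i] = x[i:(i + lookback)]
    y_train[i] = y[i + lookback]

print(tuple(x_win.shape), tc.equal(x_win, x_train), tc.equal(y_win, y_train))
# (8, 2, 2) True True
```

Note that x_win is a view of x, so call .contiguous() or .clone() on it if it needs to be modified independently.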

Finally, this is the loop for training:

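# ann, loss, optimizer, n_epochs and batch_size are defined elsewhere
for epoch in range(n_epochs):
    epoch_loss = []
    for i in range(0, data_size, batch_size):
        batch_x = x_train[i:(i + batch_size)]  # avoid shadowing the built-in input()
        batch_y = y_train[i:(i + batch_size)]
        pred = ann(batch_x)
        current_loss = loss(pred, batch_y)
        optimizer.zero_grad()       # clear gradients from the previous step
        current_loss.backward()     # backpropagate
        optimizer.step()            # update the weights
        epoch_loss.append(current_loss.item())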
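
For reference, a minimal self-contained version of the loop, including the gradient steps. The toy data, nn.MSELoss, tc.optim.Adam, the learning rate and the hyperparameters are all assumptions for illustration, not values from this post:

```python
import torch as tc
import torch.nn as nn

class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        super().__init__()
        self.lstm = nn.LSTM(input_size=x, hidden_size=n_nrns, num_layers=nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.linear(out[:, -1, :])

tc.manual_seed(0)
lookback, batch_size, n_epochs = 2, 4, 300   # assumed values

# Toy windows: 8 samples of shape (lookback, 2) with a target that depends on the inputs.
x_train = tc.randn(8, lookback, 2)
y_train = x_train.sum(dim=(1, 2)).unsqueeze(1)   # (8, 1)
data_size = x_train.shape[0]

ann = lstm_mdl(x=2, n_nrns=16, nl=1, y=1)
loss = nn.MSELoss()
optimizer = tc.optim.Adam(ann.parameters(), lr=1e-2)

history = []  # mean loss per epoch
for epoch in range(n_epochs):
    epoch_loss = []
    for i in range(0, data_size, batch_size):
        batch_x = x_train[i:(i + batch_size)]
        batch_y = y_train[i:(i + batch_size)]
        pred = ann(batch_x)
        current_loss = loss(pred, batch_y)
        optimizer.zero_grad()       # clear gradients from the previous step
        current_loss.backward()     # backpropagate
        optimizer.step()            # update the weights
        epoch_loss.append(current_loss.item())
    history.append(sum(epoch_loss) / len(epoch_loss))

print(f"first epoch: {history[0]:.3f}, last epoch: {history[-1]:.3f}")
```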

The loss decreases, but when I run the trained model on the training inputs I get weird results:
f_pred = ann(x_train) # input: torch.Size([8, 2, 2]), output: torch.Size([8, 1])

I would expect one prediction per sample (window), i.e. 8 different forecasts in total.
Instead, the forecast is exactly the same for every sample (I can provide the numbers if needed).
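
To check numerically whether the 8 forecasts really are identical, one option (a sketch; the untrained model and random inputs below are placeholders for the trained ann and the real x_train) is to run the model under tc.no_grad() and compare every row of the output with the first one:

```python
import torch as tc
import torch.nn as nn

class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        super().__init__()
        self.lstm = nn.LSTM(input_size=x, hidden_size=n_nrns, num_layers=nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.linear(out[:, -1, :])

tc.manual_seed(0)
ann = lstm_mdl(x=2, n_nrns=16, nl=1, y=1)   # placeholder for the trained model
x_train = tc.randn(8, 2, 2)                  # placeholder for the real windows

ann.eval()
with tc.no_grad():                           # inference only, no autograd graph
    f_pred = ann(x_train)                    # (8, 1): one forecast per window

# True only if every row equals the first one (within floating-point tolerance).
all_identical = tc.allclose(f_pred, f_pred[0].expand_as(f_pred))
spread = f_pred.std().item()                 # close to 0 when all forecasts are the same value
print(tuple(f_pred.shape), all_identical, round(spread, 4))
```

Run on the real trained model and x_train, all_identical and spread turn "exactly the same" into a reproducible number that can be posted alongside the question.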

Can you help with my code?