I am trying to code a very simple LSTM. Below is how I defined the main class:
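```python
import torch.nn as nn

class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        super(lstm_mdl, self).__init__()
        # x: number of input features, n_nrns: hidden units, nl: LSTM layers, y: outputs
        self.lstm = nn.LSTM(input_size = x, hidden_size = n_nrns, num_layers = nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)

    def forward(self, x):
        y, _ = self.lstm(x)           # y: [batch, seq_len, n_nrns]
        x = self.linear(y[:, -1, :])  # use only the last time step
        return x
```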
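As a quick sanity check of the class above, the sketch below instantiates it with random weights and a random dummy input shaped like `x_train` in this post (8 windows, 2 time steps, 2 features). The hidden size of 16 and the single layer are arbitrary values chosen for illustration, not my actual settings. With `batch_first=True` the LSTM output is `[batch, seq_len, hidden]`, and `y[:, -1, :]` keeps only the last time step, so the model should return one value per window.

```python
import torch as tc
import torch.nn as nn


class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        super().__init__()
        self.lstm = nn.LSTM(input_size=x, hidden_size=n_nrns,
                            num_layers=nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)

    def forward(self, x):
        y, _ = self.lstm(x)               # [batch, seq_len, n_nrns]
        return self.linear(y[:, -1, :])   # last time step -> [batch, y]


tc.manual_seed(0)
# Hypothetical sizes: 2 input features, 16 hidden units, 1 layer, 1 output
ann = lstm_mdl(x=2, n_nrns=16, nl=1, y=1)
x_dummy = tc.randn(8, 2, 2)               # [windows, lookback, features]

with tc.no_grad():
    lstm_out, _ = ann.lstm(x_dummy)       # full LSTM output sequence
    out = ann(x_dummy)                    # one forecast per window

print(lstm_out.shape)  # torch.Size([8, 2, 16])
print(out.shape)       # torch.Size([8, 1])
```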
Now, assume three time series y, A and B: we want to predict y from A and B. Each series has 10 observations, and we want a lookback of 1 period (2 periods per window, i.e. `lookback = 2` in the code below).
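For concreteness, here is a minimal sketch of how the raw inputs could be laid out, with random placeholder numbers standing in for the real series. The point it illustrates is that `A_B` has A and B side by side as two columns (one feature each), which is what gives it the shape `[10, 2]`; it is not their element-wise sum.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs = 10
# Placeholder series: replace with the real data
A = rng.normal(size=n_obs)
B = rng.normal(size=n_obs)
y = rng.normal(size=n_obs)

A_B = np.column_stack([A, B])  # shape (10, 2): one column per feature
y_col = y.reshape(-1, 1)       # shape (10, 1)

print(A_B.shape, y_col.shape)  # (10, 2) (10, 1)
```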
To do this, we run the following code on the two tensors y and A_B (A and B stacked as two columns) to get the 3D tensors we use to train the model and compute the deviations:
```python
import torch as tc

lookback = 2  # each window spans 2 periods

x = tc.tensor(A_B, dtype = tc.float32)               # torch.Size([10, 2])
y = tc.tensor(y.reshape(-1, 1), dtype = tc.float32)  # torch.Size([10, 1])

data_size = x.shape[0] - lookback                    # number of windows: 8
x_train = tc.zeros(data_size, lookback, x.shape[1])
y_train = tc.zeros(data_size, 1)

for i in range(data_size):
    x_train[i] = x[i:(i + lookback)]  # window of `lookback` consecutive rows
    y_train[i] = y[i + lookback]      # target: the period right after the window

# After the loop: x_train is torch.Size([8, 2, 2]), y_train is torch.Size([8, 1])
```
This gives us the usual 3D input tensor, where each sample holds the observations in one window: [t - lookback, t - 1], then [t - lookback - 1, t - 2], and so on. Thus, we get the usual sequences of moving windows extracted from the original dataset.
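The same windows can be built without an explicit loop. The sketch below uses NumPy's `sliding_window_view` (available from NumPy 1.20) on a small index-valued series so the contents of each window are easy to read; the variable names are illustrative. Window i holds rows i to i + lookback - 1 and its target is row i + lookback, which matches the loop above.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

lookback = 2
n_obs = 10
# Feature row t is [t, 100 + t] so each window's time indices are visible
x = np.column_stack([np.arange(n_obs), 100 + np.arange(n_obs)]).astype(np.float32)
y = np.arange(n_obs, dtype=np.float32).reshape(-1, 1)

data_size = n_obs - lookback
# Along axis 0 this yields shape (n_obs - lookback + 1, 2, lookback);
# drop the last window (it has no target) and put time before features
windows = sliding_window_view(x, lookback, axis=0)[:data_size]
x_train = windows.transpose(0, 2, 1)  # (8, 2, 2): [window, time, feature]
y_train = y[lookback:]                # (8, 1): target right after each window

print(x_train[0])  # rows for t = 0 and t = 1
print(y_train[0])  # y at t = 2
```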
Finally, this is the loop for training:
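```python
import numpy as np

# ann, loss, optimizer, n_epochs, batch_size and rec_losses are defined elsewhere
for epoch in range(n_epochs):
    epoch_loss = []
    for i in range(0, data_size, batch_size):
        x_batch = x_train[i:(i + batch_size)]
        y_batch = y_train[i:(i + batch_size)]
        pred = ann(x_batch)
        current_loss = loss(pred, y_batch)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        epoch_loss.append(np.sqrt(current_loss.item()))  # per-batch RMSE if loss is MSE
    rec_losses.append(np.array(epoch_loss).mean())        # mean over the epoch's batches
```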
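To make the batching in this loop concrete: `range(0, data_size, batch_size)` yields the start index of each mini-batch, and a slice that runs past the end is simply clipped, so the last mini-batch can be smaller than the others. The sketch lists the window indices each mini-batch covers for `data_size = 8` and a hypothetical `batch_size = 3`.

```python
data_size = 8
batch_size = 3  # hypothetical value for illustration

batches = []
for i in range(0, data_size, batch_size):
    # Same slicing as x_train[i:(i + batch_size)]: the end is clipped at data_size
    idx = list(range(data_size))[i:(i + batch_size)]
    batches.append(idx)

print(batches)  # [[0, 1, 2], [3, 4, 5], [6, 7]]
```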
The loss decreases, but when I try to run this I get weird results:
```python
f_pred = ann(x_train)  # x_train is torch.Size([8, 2, 2])
```
I would expect a prediction for each window (what I call a batch), so 8 different forecasts in total.
Instead, the forecast is exactly the same for every batch (I can provide numbers if needed).
Can you help with my code?