Hi everyone,
I am trying to code a very simple LSTM; below is how I defined the main class:
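```python
import torch.nn as nn

class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        # x: number of input features, n_nrns: hidden units, nl: LSTM layers, y: output size
        super(lstm_mdl, self).__init__()
        self.lstm = nn.LSTM(input_size=x, hidden_size=n_nrns, num_layers=nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)

    def forward(self, x):
        y, _ = self.lstm(x)           # y: (batch, seq_len, n_nrns)
        x = self.linear(y[:, -1, :])  # use only the last time step of each sequence
        return x
```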
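To check the shapes flowing through this class, here is a quick sketch with hypothetical sizes (2 input features, 16 hidden units, 1 layer, 1 output) and random input. It also shows that, for a unidirectional LSTM, `y[:, -1, :]` is the same as the last layer's final hidden state `h_n[-1]`, so the linear layer sees exactly one vector per sample.

```python
import torch
import torch.nn as nn


class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        super().__init__()
        self.lstm = nn.LSTM(input_size=x, hidden_size=n_nrns, num_layers=nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)

    def forward(self, x):
        y, _ = self.lstm(x)
        return self.linear(y[:, -1, :])


torch.manual_seed(0)
# Hypothetical sizes: 2 features, 16 hidden units, 1 layer, 1 output.
model = lstm_mdl(x=2, n_nrns=16, nl=1, y=1)
dummy = torch.randn(8, 2, 2)  # (batch=8, seq_len=2, features=2), random stand-in data

seq_out, (h_n, c_n) = model.lstm(dummy)
print(seq_out.shape)  # torch.Size([8, 2, 16]): one hidden vector per time step
print(h_n.shape)      # torch.Size([1, 8, 16]): (num_layers, batch, hidden), unaffected by batch_first
# The last time step of the output sequence equals the last layer's final hidden state.
print(torch.allclose(seq_out[:, -1, :], h_n[-1]))  # True

out = model(dummy)
print(out.shape)  # torch.Size([8, 1]): one forecast per sample
```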
Now, assume three time series y, A and B: we want to predict y from A and B. Each time series has 10 observations, and we look back 1 period, so every window contains 2 consecutive periods (`lookback = 2` in the code below).
To build the 3D tensors used for training and for computing the loss, we run the following code on y and on A_B (A and B stacked as two columns):
```python
import torch as tc

# A_B (shape (10, 2)) and y (shape (10,)) are defined earlier, e.g. as NumPy arrays
lookback = 2
x = tc.tensor(A_B, dtype=tc.float32)               # torch.Size([10, 2])
y = tc.tensor(y.reshape(-1, 1), dtype=tc.float32)  # torch.Size([10, 1])
data_size = x.shape[0] - lookback                  # 8 windows
x_train = tc.zeros(data_size, lookback, x.shape[1])  # torch.Size([8, 2, 2])
y_train = tc.zeros(data_size, 1)                     # torch.Size([8, 1])
for i in range(data_size):
    x_train[i] = x[i:(i + lookback)]  # window i: observations i to i + lookback - 1
    y_train[i] = y[i + lookback]      # target: the observation right after the window
```
This gives us the usual 3D input tensor of shape (samples, lookback, features): sample i holds the observations at times i to i + lookback - 1 and is paired with the target y[i + lookback]. In other words, we get the usual sequences of moving windows extracted from the original dataset.
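To see which time steps land in each window, here is a small sketch on hypothetical toy data where column 0 of the inputs is simply the time index and y[t] = 10 * t. It reuses the loop from the post and checks it against a vectorized alternative built with `Tensor.unfold` (a swapped-in technique, not the code used above).

```python
import torch

# Hypothetical toy data: 10 observations, 2 features, lookback of 2.
n_obs, lookback = 10, 2
t = torch.arange(n_obs, dtype=torch.float32)
x = torch.stack([t, t + 100], dim=1)  # shape (10, 2); column 0 is the time index
y = (10 * t).reshape(-1, 1)           # shape (10, 1); y[t] = 10 * t

# Loop-based windowing, as in the post.
data_size = x.shape[0] - lookback
x_train = torch.zeros(data_size, lookback, x.shape[1])
y_train = torch.zeros(data_size, 1)
for i in range(data_size):
    x_train[i] = x[i:i + lookback]
    y_train[i] = y[i + lookback]

# Vectorized alternative: unfold gives (n_obs - lookback + 1, features, lookback);
# transpose to (windows, lookback, features) and drop the last window, which has no target.
x_windows = x.unfold(0, lookback, 1).transpose(1, 2)[:data_size]
y_windows = y[lookback:]

for i in range(3):
    print(f"sample {i}: time steps {x_train[i, :, 0].tolist()} -> target y[{i + lookback}] = {y_train[i].item()}")
# sample 0: time steps [0.0, 1.0] -> target y[2] = 20.0
# sample 1: time steps [1.0, 2.0] -> target y[3] = 30.0
# sample 2: time steps [2.0, 3.0] -> target y[4] = 40.0
```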
Finally, this is the training loop:
```python
import numpy as np

# ann, loss, optimizer, n_epochs, batch_size and rec_losses are defined elsewhere
for epoch in range(n_epochs):
    epoch_loss = []
    for i in range(0, data_size, batch_size):
        x_batch = x_train[i:(i + batch_size)]
        y_batch = y_train[i:(i + batch_size)]
        pred = ann(x_batch)
        current_loss = loss(pred, y_batch)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        epoch_loss.append(np.sqrt(current_loss.item()))  # per-batch RMSE (if loss is MSE)
    rec_losses.append(np.array(epoch_loss).mean())       # mean over the epoch
```
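For anyone who wants to reproduce the loop end to end, here is a self-contained sketch on hypothetical random data. The loss (`nn.MSELoss`), the optimizer (Adam, lr = 0.01), the batch size, the number of epochs and the toy target are all assumptions, since the post does not specify them.

```python
import numpy as np
import torch
import torch.nn as nn


class lstm_mdl(nn.Module):
    def __init__(self, x, n_nrns, nl, y):
        super().__init__()
        self.lstm = nn.LSTM(input_size=x, hidden_size=n_nrns, num_layers=nl, batch_first=True)
        self.linear = nn.Linear(n_nrns, y)

    def forward(self, x):
        y, _ = self.lstm(x)
        return self.linear(y[:, -1, :])


torch.manual_seed(0)
# Hypothetical stand-in data: 8 windows, lookback 2, 2 features.
x_train = torch.randn(8, 2, 2)
# Toy target (assumption): sum of the features at the last time step.
y_train = x_train[:, -1, :].sum(dim=1, keepdim=True)  # shape (8, 1)

data_size, batch_size, n_epochs = x_train.shape[0], 4, 200  # assumed values
ann = lstm_mdl(2, 16, 1, 1)
loss = nn.MSELoss()                                      # assumed loss
optimizer = torch.optim.Adam(ann.parameters(), lr=0.01)  # assumed optimizer

rec_losses = []
for epoch in range(n_epochs):
    epoch_loss = []
    for i in range(0, data_size, batch_size):
        x_batch = x_train[i:i + batch_size]
        y_batch = y_train[i:i + batch_size]
        pred = ann(x_batch)
        current_loss = loss(pred, y_batch)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        epoch_loss.append(np.sqrt(current_loss.item()))  # per-batch RMSE
    rec_losses.append(np.mean(epoch_loss))              # mean RMSE for the epoch

print(f"RMSE: first epoch {rec_losses[0]:.3f}, last epoch {rec_losses[-1]:.3f}")
```

Averaging per-batch square roots, as the loop does, is not the same as the RMSE over the whole epoch; it works as a trend indicator, though.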
The loss decreases, but when I run the following I get unexpected results:
```python
with tc.no_grad():
    f_pred = ann(x_train)  # x_train: torch.Size([8, 2, 2]), f_pred: torch.Size([8, 1])
```
I would expect one prediction per sample (window), i.e. 8 different forecasts in total.
Instead, all 8 forecasts are exactly the same (I can provide the numbers if needed).
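One way to make "exactly the same" concrete is to run inference without gradient tracking and measure how far each forecast is from the first one. The sketch below uses a hypothetical stand-in model (`nn.Flatten` followed by `nn.Linear`) and random inputs purely so that it runs on its own; with the trained `ann` and the real `x_train` in their place, a maximum deviation of (near) zero would confirm the symptom.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins: replace with the trained `ann` and the real `x_train`.
ann = nn.Sequential(nn.Flatten(), nn.Linear(4, 1))  # maps (8, 2, 2) -> (8, 1)
x_train = torch.randn(8, 2, 2)

ann.eval()                 # switch off training-only behaviour such as dropout
with torch.no_grad():      # no autograd graph needed for inference
    f_pred = ann(x_train)  # one forecast per window: shape (8, 1)

# Largest absolute difference between any forecast and the first one.
max_dev = (f_pred - f_pred[0]).abs().max().item()
# Average spread of the inputs across the 8 windows, for comparison.
input_spread = x_train.std(dim=0).mean().item()
print(f"max deviation between forecasts: {max_dev:.6f}")
print(f"average input spread across windows: {input_spread:.6f}")
```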
Can you help with my code?