The relevant line in the forward() method is
out, _ = self.lstm(x)
So
out[-1]    # if batch_first=False (default), or
out[:, -1] # if batch_first=True
will give you the hidden states after the last time step, i.e., the LAST hidden state with respect to the forward pass but only the FIRST hidden state with respect to the backward pass; see this old post of mine. What you also want is the last hidden state with respect to the backward pass, which can be found in
out[0]     # if batch_first=False (default), or
out[:, 0]  # if batch_first=True
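You can verify this layout yourself: the last feature dimension of out holds the forward and backward outputs concatenated, and the final hidden state h of each direction matches the corresponding time step in out. A minimal runnable sketch (all dimensions here are made up for illustration):

```python
import torch
import torch.nn as nn

# Toy dimensions, chosen only for this demonstration
batch_size, seq_len, input_dim, hidden_dim = 4, 7, 5, 3

lstm = nn.LSTM(input_dim, hidden_dim, num_layers=1,
               bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, input_dim)
out, (h, c) = lstm(x)

# out: (batch, seq, 2 * hidden_dim) -- forward and backward outputs
# are concatenated along the last dimension
fwd_out, bwd_out = out[..., :hidden_dim], out[..., hidden_dim:]

# Forward pass: its last hidden state sits at the LAST time step.
# Backward pass: its last hidden state sits at the FIRST time step.
assert torch.allclose(fwd_out[:, -1], h[0])
assert torch.allclose(bwd_out[:, 0], h[1])
```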
Your network will still learn, but it will essentially make meaningful use of only the forward pass.
The cleanest way is the following:
out, (h, c) = self.lstm(x)
# Split num_layers and num_directions (1 or 2)
h = h.view(self.num_layers, self.num_directions, batch_size, self.rnn_hidden_dim)
# Last hidden state w.r.t. the number of layers, NOT the time steps
h = h[-1]
# Get last hidden states of backward and forward passes
h_forward, h_backward = h[0], h[1]
Then you can add or concatenate h_forward
and h_backward
for further layers.
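Putting it together, here is a minimal runnable sketch of that approach (the dimensions are made up; note that concatenating doubles the feature dimension, while summing keeps it unchanged):

```python
import torch
import torch.nn as nn

# Toy dimensions for illustration only
num_layers, num_directions = 2, 2
batch_size, seq_len, input_dim, hidden_dim = 4, 10, 6, 8

lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
               bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, input_dim)
out, (h, c) = lstm(x)

# Split num_layers and num_directions, then keep the top layer only
h = h.view(num_layers, num_directions, batch_size, hidden_dim)[-1]
h_forward, h_backward = h[0], h[1]

# Either sum the two directions ...
h_sum = h_forward + h_backward                      # (batch, hidden_dim)
# ... or concatenate them along the feature dimension
h_cat = torch.cat([h_forward, h_backward], dim=1)   # (batch, 2 * hidden_dim)
```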