Could someone explain batch_first=True in LSTM

The line in the forward() method is

out, _ = self.lstm(x)

So

out[-1]    # If batch_first=False, OR
out[:, -1] # If batch_first=True

will give you the hidden state at the LAST time step with respect to the forward pass, but at the FIRST time step with respect to the backward pass; see this old post of mine. What you want is also the last hidden state with respect to the backward pass, which can be found in out[0] or out[:, 0], respectively.
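To see the asymmetry concretely, you can inspect the output of a bidirectional LSTM directly. This is a minimal sketch; the dimensions are made up for illustration:

```python
import torch
import torch.nn as nn

# Toy dimensions, chosen just for this example
batch_size, seq_len, input_dim, hidden_dim = 4, 7, 5, 3

lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
x = torch.randn(batch_size, seq_len, input_dim)
out, _ = lstm(x)

# out has shape (batch, seq_len, 2 * hidden_dim); the first hidden_dim
# features come from the forward direction, the rest from the backward one.
print(out.shape)  # torch.Size([4, 7, 6])

# Last state of the forward pass vs. last state of the backward pass:
forward_last = out[:, -1, :hidden_dim]   # forward direction, last time step
backward_last = out[:, 0, hidden_dim:]   # backward direction reads the sequence
                                         # in reverse, so its last state sits at t=0
```

Note that out concatenates the two directions along the feature axis, so taking out[:, -1] alone mixes the forward pass's last state with the backward pass's first state.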

Your network will still learn, but it will essentially make meaningful use of the forward pass only.

The cleanest way is the following:

out, (h, c) = self.lstm(x)
# Split num_layers and num_directions (1 or 2)
h = h.view(self.num_layers, self.num_directions, batch_size, self.rnn_hidden_dim)
# Last hidden state w.r.t. the number of layers, NOT the time steps
h = h[-1]
# Get last hidden states of backward and forward passes
h_forward, h_backward = h[0], h[1]

Then you can add or concatenate h_forward and h_backward for further layers.
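Putting it all together, here is a runnable sketch of that approach (dimensions are invented; rnn_hidden_dim mirrors the name used above). It also sanity-checks that the reshaped h matches the corresponding slices of out:

```python
import torch
import torch.nn as nn

# Made-up dimensions for the example
num_layers, num_directions = 2, 2
batch_size, seq_len, input_dim, rnn_hidden_dim = 4, 7, 5, 3

lstm = nn.LSTM(input_dim, rnn_hidden_dim, num_layers=num_layers,
               batch_first=True, bidirectional=True)
x = torch.randn(batch_size, seq_len, input_dim)
out, (h, c) = lstm(x)

# h: (num_layers * num_directions, batch, hidden) -> split the first axis
h = h.view(num_layers, num_directions, batch_size, rnn_hidden_dim)
h = h[-1]                           # keep the top layer only
h_forward, h_backward = h[0], h[1]  # each (batch, hidden)

# Concatenate for downstream layers: (batch, 2 * hidden)
combined = torch.cat([h_forward, h_backward], dim=1)

# Sanity check: these equal the matching slices of the full output
assert torch.allclose(h_forward, out[:, -1, :rnn_hidden_dim])
assert torch.allclose(h_backward, out[:, 0, rnn_hidden_dim:])
```

Concatenation keeps the two directions distinct for the next layer, while adding them keeps the feature size at rnn_hidden_dim; either works as long as the following layer's input size matches.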
