Could someone explain batch_first=True in LSTM?

I can’t figure out how it works. I tried changing the batch size, time steps, and input size, both with batch_first=True and without it. The output always has the same leading dimensions as what I feed to the model. So do I have to change it manually?


If your input data is of shape (seq_len, batch_size, features) then you don’t need batch_first=True and your LSTM will give output of shape (seq_len, batch_size, hidden_size).

If your input data is of shape (batch_size, seq_len, features) then you need batch_first=True and your LSTM will give output of shape (batch_size, seq_len, hidden_size).
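To make the two layouts concrete, here is a minimal sketch. The sizes (seq_len=5, batch_size=3, features=4, hidden_size=2) and the variable names are arbitrary choices for illustration, not taken from the original question.

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes (assumptions, not from the question)
seq_len, batch_size, features, hidden_size = 5, 3, 4, 2

# Default layout: input is (seq_len, batch_size, features)
lstm_seq_first = nn.LSTM(features, hidden_size)
out_seq_first, _ = lstm_seq_first(torch.randn(seq_len, batch_size, features))
print(out_seq_first.shape)  # torch.Size([5, 3, 2]) = (seq_len, batch_size, hidden_size)

# batch_first=True layout: input is (batch_size, seq_len, features)
lstm_batch_first = nn.LSTM(features, hidden_size, batch_first=True)
out_batch_first, _ = lstm_batch_first(torch.randn(batch_size, seq_len, features))
print(out_batch_first.shape)  # torch.Size([3, 5, 2]) = (batch_size, seq_len, hidden_size)
```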

Is that not what you expect?


But when my input’s shape is (batch_size, seq_len, features) and I don’t set batch_first=True, the output still has the same leading dimensions as the input. That’s why I had some doubt.

no batch_first=True

had batch_first=True


If you feed it input of shape (10, 20, in_features), then the output will always be of shape (10, 20, hidden_size), because the LSTM only changes the last dimension.

Without batch_first=True it will use the first dimension as the sequence dimension.
With batch_first=True it will use the second dimension as the sequence dimension.
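One way to see the difference even though the output shapes match is to look at the final hidden state, whose shape is always (num_layers * num_directions, batch, hidden_size) regardless of batch_first. A small sketch, where in_features=4 and hidden_size=8 are assumed values chosen only for illustration:

```python
import torch
import torch.nn as nn

in_features, hidden_size = 4, 8  # assumed sizes for illustration
x = torch.randn(10, 20, in_features)

lstm_default = nn.LSTM(in_features, hidden_size)               # dim 0 is the sequence
lstm_bf = nn.LSTM(in_features, hidden_size, batch_first=True)  # dim 1 is the sequence

out_default, (h_default, _) = lstm_default(x)
out_bf, (h_bf, _) = lstm_bf(x)

# The output keeps the leading dimensions of the input in both cases
print(out_default.shape)  # torch.Size([10, 20, 8])
print(out_bf.shape)       # torch.Size([10, 20, 8])

# The hidden state reveals which dimension was treated as the batch
print(h_default.shape)  # torch.Size([1, 20, 8]): 20 sequences of 10 steps each
print(h_bf.shape)       # torch.Size([1, 10, 8]): 10 sequences of 20 steps each
```

So the same tensor is read as 20 sequences of length 10 in one case and as 10 sequences of length 20 in the other.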

It does not work. The result should be 3x2.

When you do


you are taking the last element of the batch dimension.

You probably wanted to do

print(out[:, -1])

Note that both versions are only meaningful when bidirectional=False.
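A short sketch of the two indexing choices with batch_first=True and a unidirectional, single-layer LSTM. The sizes (batch of 3, 5 time steps, 4 features, hidden size 2) are assumptions picked to match the expected 3x2 result, since the original code is not shown.

```python
import torch
import torch.nn as nn

# Assumed sizes: batch of 3, 5 time steps, 4 input features, hidden size 2
batch_size, seq_len, features, hidden_size = 3, 5, 4, 2

lstm = nn.LSTM(features, hidden_size, batch_first=True)
x = torch.randn(batch_size, seq_len, features)
out, (h, c) = lstm(x)  # out: (batch_size, seq_len, hidden_size)

last_sample = out[-1]   # all time steps of the LAST batch element
last_step = out[:, -1]  # LAST time step of every batch element

print(last_sample.shape)  # torch.Size([5, 2])
print(last_step.shape)    # torch.Size([3, 2])

# For a unidirectional LSTM, the last time step equals the final hidden state
print(torch.allclose(last_step, h[-1]))  # True
```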


Hi @vdw Chris, would you mind explaining a bit more why this is not meaningful when bidirectional is True? I recently found out that my model was trained on batch-first sequences, but the BiLSTM was set to batch_first=False. Surprisingly, the model still converged. I wish I could correct it without retraining the model.

The line in the forward() method is

out, _ = self.lstm(x)


out[-1]    # If batch_first=False OR
out[:, -1] # If batch_first=True

will give you the LAST hidden state with respect to the forward pass but the FIRST hidden state with respect to the backward pass; see this old post of mine. What you also want is the last hidden state with respect to the backward pass. This one can be found in out[0] or out[:, 0] respectively, in the second half of the feature dimension.

Your network will still learn, but it essentially makes meaningful use of only the forward pass.
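To illustrate where the two final states sit in the output of a bidirectional LSTM, here is a small sketch with batch_first=True. The sizes and variable names are assumptions for illustration. In PyTorch, the last dimension of the output holds the forward states in the first hidden_size entries and the backward states in the remaining ones.

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration
batch_size, seq_len, features, hidden_size = 3, 5, 4, 2

lstm = nn.LSTM(features, hidden_size, batch_first=True, bidirectional=True)
out, (h, _) = lstm(torch.randn(batch_size, seq_len, features))
# out: (batch_size, seq_len, 2 * hidden_size)
# h:   (2, batch_size, hidden_size), h[0] = forward, h[1] = backward

# The forward pass finishes at the LAST time step
forward_final = out[:, -1, :hidden_size]
# The backward pass finishes at the FIRST time step
backward_final = out[:, 0, hidden_size:]

print(torch.allclose(forward_final, h[0]))   # True
print(torch.allclose(backward_final, h[1]))  # True
```

Taking out[:, -1] alone therefore pairs the complete forward state with a backward state that has seen only a single time step.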

The cleanest way is the following

out, (h, c) = self.lstm(x)
# Split num_layers and num_directions (1 or 2)
h = h.view(self.num_layers, self.num_directions, batch_size, self.rnn_hidden_dim)
# Keep only the last layer (last w.r.t. the number of layers, NOT the time steps)
h = h[-1]
# Get the last hidden states of the forward and backward passes
h_forward, h_backward = h[0], h[1]

Then you can add or concatenate h_forward and h_backward for further layers.
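As a self-contained sketch of this approach outside a module, assuming a 2-layer bidirectional LSTM with batch_first=True (the sizes and the names summary_cat and summary_sum are made up for illustration):

```python
import torch
import torch.nn as nn

# Assumed configuration for illustration
num_layers, num_directions = 2, 2
batch_size, seq_len, features, hidden_size = 3, 5, 4, 6

lstm = nn.LSTM(features, hidden_size, num_layers=num_layers,
               batch_first=True, bidirectional=True)
x = torch.randn(batch_size, seq_len, features)
out, (h, c) = lstm(x)  # h: (num_layers * num_directions, batch_size, hidden_size)

# Split num_layers and num_directions, then keep only the last layer
h = h.view(num_layers, num_directions, batch_size, hidden_size)
h_forward, h_backward = h[-1][0], h[-1][1]

# Option 1: concatenate, giving (batch_size, 2 * hidden_size)
summary_cat = torch.cat([h_forward, h_backward], dim=1)
# Option 2: add, giving (batch_size, hidden_size)
summary_sum = h_forward + h_backward

print(summary_cat.shape)  # torch.Size([3, 12])
print(summary_sum.shape)  # torch.Size([3, 6])
```

Note that h always has the batch in its second-to-last dimension, so this works the same way whether or not batch_first=True is set. If you concatenate, the next layer needs 2 * hidden_size input features.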

