Could someone explain batch_first=True in LSTM?

I can’t figure out how it works. I tried changing the batch size, time steps and input size, with and without batch_first=True, but the output always comes back in the same layout as the input I feed to the model. So do I have to change it manually?


If your input data is of shape (seq_len, batch_size, features) then you don’t need batch_first=True and your LSTM will give output of shape (seq_len, batch_size, hidden_size).

If your input data is of shape (batch_size, seq_len, features) then you need batch_first=True and your LSTM will give output of shape (batch_size, seq_len, hidden_size).

Is that not what you expect?
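A minimal sketch of the two layouts described above. The sizes (seq_len=5, batch_size=3, features=4, hidden_size=2) are arbitrary and chosen only for illustration:

```python
import torch
import torch.nn as nn

seq_len, batch_size, features, hidden_size = 5, 3, 4, 2  # illustrative sizes

# Default (batch_first=False): input and output are (seq_len, batch_size, ...)
lstm_seq_first = nn.LSTM(input_size=features, hidden_size=hidden_size)
x_seq_first = torch.randn(seq_len, batch_size, features)
out_seq_first, _ = lstm_seq_first(x_seq_first)
print(out_seq_first.shape)  # torch.Size([5, 3, 2])

# batch_first=True: input and output are (batch_size, seq_len, ...)
lstm_batch_first = nn.LSTM(input_size=features, hidden_size=hidden_size, batch_first=True)
x_batch_first = torch.randn(batch_size, seq_len, features)
out_batch_first, _ = lstm_batch_first(x_batch_first)
print(out_batch_first.shape)  # torch.Size([3, 5, 2])
```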


But when my input has shape (batch_size, seq_len, features) and I don't set batch_first=True, the output still has the same layout as the input. That's why I had a little doubt.

Without batch_first=True:

https://i.imgur.com/51SRQni.png

With batch_first=True:

https://i.imgur.com/PI0Excd.png

If you feed it input of shape (10, 20, in_features), the output will always be of shape (10, 20, hidden_size), with or without batch_first=True. What changes is which dimension is treated as the sequence:

Without batch_first=True it will use the first dimension as the sequence dimension.
With batch_first=True it will use the second dimension as the sequence dimension.
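To see that the shape stays the same but the computation does not, here is a small sketch (sizes assumed for illustration) that gives two LSTMs identical weights, differing only in batch_first, and feeds both the same (10, 20, in_features) tensor:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_features, hidden_size = 4, 2  # illustrative sizes
x = torch.randn(10, 20, in_features)

lstm_a = nn.LSTM(in_features, hidden_size)                    # dim 0 is the sequence
lstm_b = nn.LSTM(in_features, hidden_size, batch_first=True)  # dim 1 is the sequence
lstm_b.load_state_dict(lstm_a.state_dict())  # same weights, only batch_first differs

out_a, _ = lstm_a(x)
out_b, _ = lstm_b(x)
print(out_a.shape, out_b.shape)  # both torch.Size([10, 20, 2])

# Same shape, different values: the recurrence runs over a different dimension
print(torch.allclose(out_a, out_b))  # False

# batch_first=True behaves like swapping dims 0 and 1 around a default LSTM
out_a_on_swapped, _ = lstm_a(x.transpose(0, 1))
print(torch.allclose(out_b, out_a_on_swapped.transpose(0, 1), atol=1e-6))  # True
```

So PyTorch never reshapes your data for you; batch_first only tells the LSTM how to interpret the layout you pass in.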


It does not work; the result should be 3x2.

I really appreciate that you reply to all my topics, thanks a lot!

When you do

print(out[-1])

you are taking the last element of the batch dimension, since with batch_first=True the batch is dimension 0.

You probably wanted to do

print(out[:, -1])
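A quick check of the difference, assuming batch_size=3 and hidden_size=2 to match the expected 3x2 result (seq_len and features are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, features, hidden_size = 3, 5, 4, 2  # illustrative sizes
lstm = nn.LSTM(features, hidden_size, batch_first=True)
x = torch.randn(batch_size, seq_len, features)
out, (h_n, c_n) = lstm(x)  # out: (3, 5, 2)

last_in_batch = out[-1]     # last SEQUENCE in the batch, all time steps -> (5, 2)
last_time_step = out[:, -1]  # last TIME STEP of every sequence -> (3, 2)
print(last_in_batch.shape, last_time_step.shape)  # torch.Size([5, 2]) torch.Size([3, 2])

# For a single-layer, unidirectional LSTM (no padding), this equals h_n[-1]
print(torch.allclose(last_time_step, h_n[-1]))  # True
```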

Note that both versions are only meaningful when bidirectional=False.


Hi @vdw Chris, would you mind explaining a bit more why this is not meaningful when bidirectional=True? I recently found out that my model was trained with batch-first sequences, but the BiLSTM was set to batch_first=False. Surprisingly, the model still converged. I wish I could correct it without retraining the model.

The line in the forward() method is

out, _ = self.lstm(x)

So

out[-1]    # If batch_first=False OR
out[:, -1] # If batch_first=True

will give you the LAST hidden state with respect to the forward pass but the FIRST hidden state with respect to the backward pass, because the output at each time step is the concatenation [forward | backward] along the last dimension; see this old post of mine. What you want instead is the last hidden state with respect to the backward pass as well. That one can be found in the second half of the feature dimension of out[0] or out[:, 0], respectively.

Your network will still learn, but it essentially makes meaningful use of only the forward pass.
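A small sketch that verifies this for a single-layer BiLSTM with batch_first=False (sizes assumed for illustration; H is the hidden size per direction):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, features, H = 5, 3, 4, 2  # illustrative sizes
lstm = nn.LSTM(features, H, bidirectional=True)  # batch_first=False
x = torch.randn(seq_len, batch_size, features)
out, (h_n, c_n) = lstm(x)  # out: (5, 3, 2*H); h_n: (2, 3, H)

# The last dim of out is [forward | backward]
fwd_last = out[-1, :, :H]   # forward pass after reading the whole sequence
bwd_first = out[-1, :, H:]  # backward pass after reading only x[-1]
bwd_last = out[0, :, H:]    # backward pass after reading the whole sequence

print(torch.allclose(fwd_last, h_n[0]))   # True
print(torch.allclose(bwd_last, h_n[1]))   # True
print(torch.allclose(bwd_first, h_n[1]))  # False in general
```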

The cleanest way is to use the returned hidden state h:

out, (h, c) = self.lstm(x)
# h has shape (num_layers * num_directions, batch_size, hidden_dim), even with batch_first=True
batch_size = h.size(1)
# Split num_layers and num_directions (1 or 2)
h = h.view(self.num_layers, self.num_directions, batch_size, self.rnn_hidden_dim)
# Take the last LAYER (indexing over num_layers, NOT over time steps)
h = h[-1]
# Get the last hidden states of the forward and backward passes
h_forward, h_backward = h[0], h[1]

Then you can add or concatenate h_forward and h_backward for further layers.
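A standalone sketch of the same approach outside a module, assuming a 2-layer BiLSTM with batch_first=True (all sizes are illustrative; num_layers, num_directions and hidden_dim stand in for the self.* attributes above). It cross-checks h_forward and h_backward against the top layer's output and then combines them:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, num_directions, hidden_dim = 2, 2, 3  # illustrative sizes
batch_size, seq_len, features = 4, 6, 5
lstm = nn.LSTM(features, hidden_dim, num_layers=num_layers,
               bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, features)
out, (h, c) = lstm(x)  # out: (4, 6, 2*hidden_dim); h: (num_layers*2, 4, hidden_dim)

# h ignores batch_first: the batch is always dim 1
h = h.view(num_layers, num_directions, h.size(1), hidden_dim)
h_forward, h_backward = h[-1][0], h[-1][1]  # last layer, each (4, hidden_dim)

# Cross-check against the top layer's per-time-step output
print(torch.allclose(h_forward, out[:, -1, :hidden_dim]))  # True
print(torch.allclose(h_backward, out[:, 0, hidden_dim:]))  # True

# Combine the two directions for a downstream layer
h_cat = torch.cat([h_forward, h_backward], dim=1)  # (4, 2*hidden_dim)
h_sum = h_forward + h_backward                     # (4, hidden_dim)
print(h_cat.shape, h_sum.shape)  # torch.Size([4, 6]) torch.Size([4, 3])
```

Concatenating keeps both directions separate for the next layer at the cost of doubling its input size; adding keeps the size at hidden_dim.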
