Confusion regarding PyTorch LSTMs compared to Keras stateful LSTM

I was running into a similar issue with PyTorch vs Keras. But then I realized that in Keras, when you set stateful=True, you are essentially turning your data into one longer sequence with batch size 1.

For example, say X has shape (B, L, H), where B is the batch size, L is the sequence length, and H is the feature (input) dimension. A Keras LSTM with stateful=True is then the same as using a batch size of 1 and concatenating the B sequences one after another along the time axis, so the single sequence now has length B·L, i.e. the input X has shape (1, B·L, H).

So by reshaping your input data, you get the same behavior, and this is easy to do in PyTorch.
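A minimal sketch of that reshape, with hypothetical sizes (the names B, L, H just mirror the notation above); because the tensor is laid out row-major, `reshape` concatenates the B sequences in order along the time axis:

```python
import torch

# Hypothetical sizes: B batches, L timesteps, H features
B, L, H = 4, 10, 8
X = torch.randn(B, L, H)

# Concatenate the B sequences along the time axis:
# one long sequence of length B*L with batch size 1.
X_long = X.reshape(1, B * L, H)
print(X_long.shape)  # torch.Size([1, 40, 8])
```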

And in theory there should be no difference in space and time complexity between the two approaches, because once you set stateful=True in Keras, it has to process the batches sequentially, from the first batch to the last (i.e. it can no longer process them in parallel): the final hidden state from batch b0 is needed as the initial hidden state for batch b1, and so forth for every subsequent batch.
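To illustrate the equivalence, here is a sketch (sizes and the `batch_first=True` setup are illustrative, not from the original post) that feeds each batch of a PyTorch `nn.LSTM` sequentially while carrying the hidden state forward, then checks that this matches a single pass over the reshaped long sequence:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, H = 4, 10, 8  # hypothetical sizes
lstm = nn.LSTM(input_size=H, hidden_size=H, batch_first=True)
X = torch.randn(B, L, H)

# Stateful-style processing: feed each "batch" one at a time,
# passing the final (h, c) of batch b as the initial state of batch b+1.
state = None
outs = []
for b in range(B):
    out, state = lstm(X[b : b + 1], state)  # out: (1, L, H)
    outs.append(out)
stateful_out = torch.cat(outs, dim=1)  # (1, B*L, H)

# Equivalent: reshape into one long sequence and run once.
long_out, _ = lstm(X.reshape(1, B * L, H))
print(torch.allclose(stateful_out, long_out, atol=1e-5))  # True
```

The outputs agree (up to floating-point rounding) because the LSTM recurrence is sequential over time either way; the reshape just makes that sequential dependence explicit in a single call.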

Hope this helps any future reader running into this.
