Batch size handling with LSTM

Hi All,
I am trying to implement a custom LSTM layer with a custom cell.
It works when I pass a single sample, but when I pass a batch of data, I run into a problem.
For example:
If my data looks like [128, 64], where 128 is the maximum sequence length and 64 is the word embedding size, the model works fine.
How should I handle a batch of data, e.g. [16, 128, 64]?
I searched for the implementation of the LSTM cell in PyTorch, but I can't find the code that handles the batch

and this code leads to this implementation.

Does PyTorch loop over the samples in the batch with a for loop?
Thanks in advance.
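On the question of whether the batch is looped over: here is a minimal sketch, assuming batch-first tensors, of the general idea. The layer `proj` and all sizes are hypothetical. A layer such as `nn.Linear`, which custom cells are typically built from, processes every sample of the batch in one matrix multiplication and gives the same result as an explicit per-sample loop. This illustrates the principle only; it is not PyTorch's internal LSTM code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical input-to-gates projection of a cell: 4 gates * hidden size 20
proj = nn.Linear(64, 80)
x = torch.randn(16, 64)  # one time step for a batch of 16 samples

# Vectorized: a single matmul covers all 16 samples
batched = proj(x)  # shape (16, 80)

# Reference: an explicit Python loop over the samples
looped = torch.stack([proj(x[b]) for b in range(16)])

print(batched.shape)                               # torch.Size([16, 80])
print(torch.allclose(batched, looped, atol=1e-5))  # True

# nn.Linear acts on the last dimension, so extra leading dims also work
print(proj(torch.randn(16, 128, 64)).shape)        # torch.Size([16, 128, 80])
```

So a cell written for one sample usually works for a batch unchanged, as long as its operations act on the last dimension and its states carry a leading batch dimension.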


Any help would be appreciated, @ptrblck.

@Ahmed_Akl, you can do it, but keep in mind that the states, such as the hidden state and the cell state, depend on the batch size. This video might help answer your question: torch.nn.RNN Module explained - YouTube
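A minimal sketch of what this means in practice, assuming batch-first input of shape [batch, seq_len, emb]. `CustomLSTMSketch` and its gate layout are hypothetical, not the original poster's cell or PyTorch's `nn.LSTM`. The initial hidden and cell states are created from the input's batch size, the cell runs on the whole batch at each step, and the only Python loop is over time steps. A single [128, 64] sample is given a batch dimension of 1 so the same code path handles both cases.

```python
import torch
import torch.nn as nn

class CustomLSTMSketch(nn.Module):
    """Hypothetical LSTM layer: loops over time, processes the batch in parallel."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Each Linear produces all four gates (i, f, g, o) at once
        self.x2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)

    def step(self, x_t, h, c):
        # x_t: (batch, input_size); h, c: (batch, hidden_size)
        i, f, g, o = (self.x2h(x_t) + self.h2h(h)).chunk(4, dim=-1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c

    def forward(self, x):
        if x.dim() == 2:        # single sample: (seq_len, input_size)
            x = x.unsqueeze(0)  # -> (1, seq_len, input_size)
        batch_size, seq_len, _ = x.shape
        # The initial states must match the batch size of the input
        h = x.new_zeros(batch_size, self.hidden_size)
        c = x.new_zeros(batch_size, self.hidden_size)
        outputs = []
        for t in range(seq_len):  # loop over time steps, not over samples
            h, c = self.step(x[:, t, :], h, c)
            outputs.append(h)
        # outputs: (batch, seq_len, hidden_size); final states: (batch, hidden_size)
        return torch.stack(outputs, dim=1), (h, c)

lstm = CustomLSTMSketch(input_size=64, hidden_size=32)
out, (h_n, c_n) = lstm(torch.randn(16, 128, 64))
print(out.shape, h_n.shape)  # torch.Size([16, 128, 32]) torch.Size([16, 32])
```

Because only the time dimension is looped in Python, the loop's cost grows with sequence length, not with batch size.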