Best practices for batching data for stateful LSTM and text generation

I’m building an LSTM network for text generation and will train it on the first chapter of a book. My understanding is that this is a good use case for a stateful LSTM, since it amounts to document-level prediction. However, I’m having trouble working out the best structure for the batches and how to create a custom DataLoader for this purpose.

One approach I’ve thought of is to create batches that contain a contiguous portion of the chapter’s text and are shaped (batch_size, timesteps), where timesteps are words. That is, if you merged the batches along their columns, you would recover a contiguous stretch of the chapter. Here’s an example of what the batch inputs would look like:

Batch #1
--------------------------------Timesteps --------------------
batch element 1| word_1 | word_2 | word_3 |
batch element 2| word_2 | word_3 | word_4 |
batch element 3| word_3 | word_4 | word_5 |

Batch #2
--------------------------------Timesteps --------------------
batch element 1| word_4 | word_5 | word_6 |
batch element 2| word_5 | word_6 | word_7 |
batch element 3| word_6 | word_7 | word_8 |

The target batch would have the timesteps shifted by +1, so the model predicts the next word for the corresponding timestep in the input batch. Note that each row in a batch starts one timestep after the row above it.

I believe this approach would allow passing the LSTM states between batch iterations, since the i-th row of a given batch contains the text that precedes the text of the next batch’s i-th row. The states produced by the LSTM layer have shape (num_layers, batch_size, hidden_out), so the batch dimension of the states corresponds to the batch dimension of the inputs.
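The batch layout above could be sketched like this (a rough sketch; the function name and the use of plain lists instead of tensors are illustrative, not from my actual code):

```python
def make_contiguous_batches(tokens, batch_size, timesteps):
    """Yield (inputs, targets) as nested lists shaped (batch_size, timesteps).

    Row i of batch k is continued by row i of batch k + 1, so the LSTM
    state for each batch element stays aligned across iterations.
    Convert the lists to LongTensors before feeding an embedding layer.
    """
    k = 0
    while True:
        base = k * timesteps
        # the last row's shifted target needs tokens up to index
        # base + batch_size - 1 + timesteps + 1
        if base + batch_size + timesteps > len(tokens):
            break
        inputs = [tokens[base + i : base + i + timesteps]
                  for i in range(batch_size)]
        targets = [tokens[base + i + 1 : base + i + 1 + timesteps]
                   for i in range(batch_size)]
        yield inputs, targets
        k += 1
```

With batch_size = 3 and timesteps = 3 this reproduces the two example batches above, with targets shifted by +1.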

I tried to create custom Dataset and DataLoader classes for this purpose but have gotten weird results. Specifically, my questions are:

  1. Is the mentioned batch structure a good approach for creating batches for training a text generation LSTM?
  2. Are custom Dataset/DataLoader classes a good approach here, or is it better to build a tensor with the input data and iterate over it with indexes inside the training function?


Yes, this organization of batches should be the correct way to go if you want to “preserve” the hidden state between batches. You probably want to detach() the hidden state between iterations, though, so the computational graph doesn’t keep growing across batches.
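To illustrate the detach() point (a minimal sketch; the layer sizes and the random-input loop are placeholders, not taken from your setup): carry the (h, c) tuple forward and detach both tensors each iteration, so backpropagation stops at the batch boundary.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
state = None  # nn.LSTM initializes (h_0, c_0) to zeros when state is None

for _ in range(3):  # stand-in for iterating over your contiguous batches
    x = torch.randn(4, 5, 8)  # (batch_size, timesteps, input_size)
    out, state = lstm(x, state)
    # detach so gradients don't flow back through every previous batch
    state = tuple(s.detach() for s in state)
```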

You can probably implement this using a Dataset and a Sampler. Here is a basic example that uses this approach to create batches where all sequences in a batch are guaranteed to have the same length. As you can see, the relevant logic is in the Sampler class.
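As a rough sketch of what the Sampler side could look like (the class name, and implementing it as a plain iterable passed to DataLoader’s batch_sampler argument, are my assumptions): if the Dataset’s item i is the window text[i : i + seq_len], then a batch sampler that yields batch_size consecutive indexes and advances by seq_len between batches reproduces exactly the layout in the question.

```python
class ContiguousBatchSampler:
    """Yields [base, base+1, ..., base+batch_size-1], advancing base by
    seq_len each batch, so row i of batch k+1 continues row i of batch k.
    Pass it as batch_sampler=... to torch.utils.data.DataLoader.
    """
    def __init__(self, num_windows, batch_size, seq_len):
        self.num_windows = num_windows
        self.batch_size = batch_size
        self.seq_len = seq_len

    def __iter__(self):
        base = 0
        while base + self.batch_size <= self.num_windows:
            yield list(range(base, base + self.batch_size))
            base += self.seq_len

    def __len__(self):
        return max(0, (self.num_windows - self.batch_size) // self.seq_len + 1)
```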

Thanks @vdw! I’ll keep your Sampler example in mind for future implementations. I actually managed to implement this in a different way, using an approach I found in a book: take the encoded text (a list of embedding indexes representing the text) and divide it into chunks, with each chunk sliding +1 word with respect to the previous one. The chunks have size seq_len + 1, so each chunk contains both the input sequence and the target sequence for the model (input: [0:-1], target: [1:]). A custom Dataset class retrieves a chunk from the chunk list by index and returns a tuple with the model’s input and target. This Dataset is passed to a DataLoader with drop_last=True, generating mini-batches that are contiguous in the way I described above.
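For reference, a minimal sketch of that chunk-based approach (class and variable names are mine, not from the book): with shuffle=False (the default) and drop_last=True, the DataLoader’s default sequential batching already produces the contiguous layout described above.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ChunkDataset(Dataset):
    """Sliding chunks of size seq_len + 1; each item is (input, target)."""
    def __init__(self, encoded, seq_len):
        # encoded: list of embedding indexes representing the text
        self.chunks = [encoded[i : i + seq_len + 1]
                       for i in range(len(encoded) - seq_len)]

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        chunk = torch.tensor(self.chunks[idx])
        return chunk[:-1], chunk[1:]  # input: [0:-1], target: [1:]

encoded = list(range(100))  # toy stand-in for the encoded chapter
loader = DataLoader(ChunkDataset(encoded, seq_len=3), batch_size=3,
                    shuffle=False, drop_last=True)
```

Row i of each mini-batch picks up exactly where row i of the previous mini-batch left off, so the hidden state can be carried across iterations as discussed above.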