I’m building an LSTM network for text generation and will train it on the first chapter of a book. My understanding is that this is a good use case for a stateful LSTM, since the task is document-level prediction. However, I’m having trouble working out the best structure for the batches and how to create a custom DataLoader for this purpose.
One approach I’ve thought of is to create batches that each contain a continuous part of the chapter’s text and are shaped (batch_size, timesteps), where each timestep is one word. That is, if you joined consecutive batches column-wise (along the timestep dimension), each row would read as a continuous passage of the chapter. Here’s an example of what the input batches would look like:
Batch #1
--------------------------------Timesteps --------------------
batch element 1| word_1 | word_2 | word_3 |
batch element 2| word_2 | word_3 | word_4 |
batch element 3| word_3 | word_4 | word_5 |
…
Batch #2
--------------------------------Timesteps --------------------
batch element 1| word_4 | word_5 | word_6 |
batch element 2| word_5 | word_6 | word_7 |
batch element 3| word_6 | word_7 | word_8 |
The target batch would have the timesteps shifted by one position, so the model predicts the next word for each timestep in the input batch. Note that each row in a batch starts one timestep after the row above it.
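A minimal sketch of this layout, assuming the chapter has already been converted to a list of integer word ids (the function name make_overlapping_batches is hypothetical). Row i of batch b starts at word b * timesteps + i, which reproduces Batch #1 and Batch #2 above; the batches are plain nested lists that can be wrapped with torch.tensor(...) to get (batch_size, timesteps) tensors.

```python
def make_overlapping_batches(tokens, batch_size, timesteps):
    """Hypothetical helper: yield (inputs, targets) pairs in the layout above.

    Row i of batch b starts at token b * timesteps + i, so row i of batch b + 1
    continues exactly where row i of batch b stopped. Targets are the same
    windows shifted one token to the right.
    """
    b = 0
    while True:
        start = b * timesteps
        # The last row needs timesteps + 1 tokens (its inputs plus one extra target).
        if start + (batch_size - 1) + timesteps + 1 > len(tokens):
            break
        inputs, targets = [], []
        for i in range(batch_size):
            s = start + i
            inputs.append(tokens[s:s + timesteps])
            targets.append(tokens[s + 1:s + timesteps + 1])
        yield inputs, targets
        b += 1


# Example: ids 1..11 stand in for word_1..word_11.
for inputs, targets in make_overlapping_batches(list(range(1, 12)), batch_size=3, timesteps=3):
    print(inputs, targets)
```

With this layout, adjacent rows share all but one word, so each batch covers only batch_size + timesteps - 1 distinct words of the chapter.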
I believe this approach would allow me to pass the LSTM states between batch iterations, since the i-th row of a given batch contains the text that immediately precedes the i-th row of the next batch. The states returned by the LSTM layer have shape (num_layers, batch_size, hidden_out), so there is a correspondence along the batch dimension between the inputs and the LSTM states.
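One way the state could be carried between batches in PyTorch, as a sketch: the WordLSTM class, the sizes, and the train_one_pass helper below are hypothetical. The loop assumes every batch has the same batch_size, because a smaller final batch would not match the shape of the carried (h, c). Detaching the state keeps its values but stops gradients from flowing back into earlier batches (truncated backpropagation through time).

```python
import torch
import torch.nn as nn

# Hypothetical model sizes, chosen only for illustration.
VOCAB_SIZE, EMBED_DIM, HIDDEN, NUM_LAYERS = 50, 16, 32, 2


class WordLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
        # batch_first=True means inputs are (batch_size, timesteps); the
        # returned (h, c) are still (num_layers, batch_size, hidden).
        self.lstm = nn.LSTM(EMBED_DIM, HIDDEN, NUM_LAYERS, batch_first=True)
        self.out = nn.Linear(HIDDEN, VOCAB_SIZE)

    def forward(self, x, state=None):
        y, state = self.lstm(self.embed(x), state)
        return self.out(y), state


def train_one_pass(model, batches, lr=1e-2):
    """Hypothetical helper: one pass over (inputs, targets), carrying LSTM state.

    Row i of each batch is assumed to continue row i of the previous batch,
    so the final (h, c) of one step is used as the initial state of the next.
    """
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    state = None  # nn.LSTM starts from zeros when no state is passed
    losses = []
    for inputs, targets in batches:
        if state is not None:
            # Keep the values but cut the graph, so backprop stays within one batch.
            state = tuple(s.detach() for s in state)
        logits, state = model(inputs, state)
        # CrossEntropyLoss expects (N, C) logits and (N,) class indices.
        loss = loss_fn(logits.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return state, losses
```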
I tried to create custom Dataset and DataLoader classes for this purpose but have gotten unexpected results. Specifically, my questions are:
- Is the batch structure described above a good approach for training a text generation LSTM?
- Are custom DataLoader classes a good approach here, or is it better to create a tensor with the input data and iterate over it with indexes in the training function?
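For comparison with plain tensor indexing, one possible sketch of the Dataset route (StatefulBatchDataset is a hypothetical name, and tokens is assumed to be a list of integer word ids) is a map-style Dataset whose items are already whole batches, loaded with batch_size=None so the DataLoader does no extra batching:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class StatefulBatchDataset(Dataset):
    """Hypothetical map-style Dataset whose items are whole (inputs, targets) batches.

    Row i of batch b starts at token b * timesteps + i, matching the layout in
    the question, so item b + 1 continues item b row by row.
    """

    def __init__(self, tokens, batch_size, timesteps):
        self.tokens = torch.as_tensor(tokens, dtype=torch.long)
        self.batch_size = batch_size
        self.timesteps = timesteps
        # Batch b is complete only if its last row still has a next-word target:
        # b * timesteps + batch_size + timesteps <= len(tokens).
        usable = len(self.tokens) - batch_size - timesteps
        self.num_batches = usable // timesteps + 1 if usable >= 0 else 0

    def __len__(self):
        return self.num_batches

    def __getitem__(self, b):
        if not 0 <= b < self.num_batches:
            raise IndexError(b)
        start = b * self.timesteps
        # idx[i, t] = start + i + t, shape (batch_size, timesteps)
        idx = (start
               + torch.arange(self.batch_size).unsqueeze(1)
               + torch.arange(self.timesteps).unsqueeze(0))
        # Targets are the same windows shifted one word to the right.
        return self.tokens[idx], self.tokens[idx + 1]


# Example: ids 1..11 stand in for word_1..word_11.
loader = DataLoader(StatefulBatchDataset(list(range(1, 12)), batch_size=3, timesteps=3),
                    batch_size=None, shuffle=False)
for inputs, targets in loader:
    print(inputs.shape, targets.shape)
```

With batch_size=None and shuffle=False, the DataLoader simply returns each pre-built batch in order; shuffling would break the row-to-row continuity that the carried state depends on. The tensors produced are the same as indexing a single tensor by hand, so the Dataset mainly buys a standard interface.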
Thanks!