Batching Examples by Document

I am working on a document-based dataset where each sentence is a sample (a torch Example). My goal is to create an iterator where each batch represents a document and its sentences are the samples, so the batch size is dynamic and depends on the number of sentences in the document.

In the normal case, I would have used the following:

    data_iter = data.Iterator(
        dataset,
        config.batch_size,
        repeat=False
    )

In my case I could use the batch_size_fn parameter of Iterator, but that doesn’t give me access to the previous sample in the batch, so I can’t compare document ids and only add a sentence to the batch if they match. I thought about creating a wrapper iterator, but I have no idea how to get this done. Any ideas would be appreciated!
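For illustration, this is roughly what such a hook looks like in the legacy torchtext API (a sketch; the exact signature may differ between versions):

    def doc_batch_size_fn(new_example, count, size_so_far):
        # Only the incoming example, the running count of examples, and the
        # running size are passed in; there is no handle to the previously
        # added example, so its doc_id cannot be compared with new_example.doc_id.
        return count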

You can always write your own Dataset and Sampler to accommodate your needs. For example, here I wrote one that ensures that all sequences in a batch have the same length.
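A minimal sketch of that idea in plain PyTorch (all names here are illustrative, not an existing API): a custom batch sampler buckets example indices by sequence length, so every batch the DataLoader produces contains sequences of equal length.

    import random
    from collections import defaultdict

    import torch
    from torch.utils.data import DataLoader, Dataset, Sampler


    class SentenceDataset(Dataset):
        """Hypothetical dataset: each item is a (token_ids, label) pair."""

        def __init__(self, sequences, labels):
            self.sequences = sequences
            self.labels = labels

        def __len__(self):
            return len(self.sequences)

        def __getitem__(self, idx):
            return torch.tensor(self.sequences[idx]), torch.tensor(self.labels[idx])


    class SameLengthBatchSampler(Sampler):
        """Yield batches of indices whose sequences all share one length."""

        def __init__(self, sequences, batch_size, shuffle=True):
            self.batch_size = batch_size
            self.shuffle = shuffle
            # Bucket example indices by sequence length.
            self.buckets = defaultdict(list)
            for idx, seq in enumerate(sequences):
                self.buckets[len(seq)].append(idx)

        def __iter__(self):
            batches = []
            for indices in self.buckets.values():
                if self.shuffle:
                    random.shuffle(indices)
                # Chop each length bucket into fixed-size batches of indices.
                for i in range(0, len(indices), self.batch_size):
                    batches.append(indices[i:i + self.batch_size])
            if self.shuffle:
                random.shuffle(batches)
            return iter(batches)

        def __len__(self):
            return sum(
                (len(idxs) + self.batch_size - 1) // self.batch_size
                for idxs in self.buckets.values()
            )


    # Toy usage: the sampler hands whole batches of indices to the DataLoader.
    sequences = [[1, 2], [3, 4], [5, 6, 7], [8, 9, 10]]
    labels = [0, 1, 0, 1]
    dataset = SentenceDataset(sequences, labels)
    loader = DataLoader(dataset, batch_sampler=SameLengthBatchSampler(sequences, batch_size=2))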

Apart from that, why would you want to do that? Is there any reason to keep all sentences of one document together in a batch? Having batches of (very) different sizes can also be problematic: larger batches allow higher learning rates, which can be problematic for smaller batches, because the more samples you have in a batch, the more the gradients average out. See this paper.

Thanks for your response. I will check your code. To answer your question, I want to use that sampling approach in my evaluation script, in which I use the iterator to generate a document as a batch, process all sentences in that batch, and evaluate/predict the batch output as a single document all at once. I am not sure if this is the way to go, but I am just experimenting. Or do you have a better suggestion?

I solved my problem by overriding the create_batches method inside data.Iterator.

    class EvalIterator(data.Iterator):
        def create_batches(self):
            # Group the examples by document instead of by a fixed batch size.
            self.batches = custom_batch(self.data())


    def custom_batch(data):
        """Yield minibatches where each batch holds the sentences of one document."""
        minibatch = []
        old_doc_id = None

        for ex in data:
            minibatch.append(ex)

            if old_doc_id is None:
                old_doc_id = ex.doc_id

            if old_doc_id != ex.doc_id:
                # A new document starts here: emit everything collected so far
                # and begin the next batch with the current example.
                old_doc_id = ex.doc_id
                yield minibatch[:-1]
                minibatch = minibatch[-1:]

        if minibatch:
            yield minibatch
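A rough usage sketch (hypothetical setup: `dataset` is a torchtext Dataset whose Examples all carry a doc_id attribute; the batch_size argument is still required by Iterator but is ignored by the overridden create_batches):

    eval_iter = EvalIterator(
        dataset,
        batch_size=1,
        train=False,
        sort=False,
        repeat=False,
    )

    for batch in eval_iter:
        # Every sentence in this batch belongs to the same document.
        ...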

I’m just not sure why it is so important to have all sentences of the same document in the same batch.

The normal approach is to split all documents into sentences, preprocess all sentences, and then use the whole set of sentences for training/validation/testing. Here, you can simply fix a batch size of, say, 32 or 64. Each batch can then contain sentences from more than one document, or even from 32/64 different documents if you shuffle your dataset of sentences first.
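In code, that standard setup is just the plain iterator from your first snippet with a fixed batch size and shuffling enabled (a sketch, assuming `dataset` now holds the sentence-level Examples):

    train_iter = data.Iterator(
        dataset,
        batch_size=32,
        shuffle=True,
        repeat=False,
    )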