Batching and outputting multiple objects with IterableDataset

Hi All,

I have a very specific application: I use grammar rules to generate random sequences of tokens (up to a maximum length) and train a VAE on them. The sequences are generated on the fly in an IterableDataset, which converts each token sequence into a collection of one-hot vectors (one per sequence timestep). Note that the VAE is defined as a PyTorch Lightning model. Here’s the outline of my custom dataset class:
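For context, the one-hot step can be pictured with a minimal sketch like the one below. It assumes a hypothetical character-level vocabulary (`VOCAB`, `TOKEN_TO_IDX` and `sentence_to_one_hot` are illustrative names, not part of my code). The real `make_one_hot` also takes the grammar, tokenizer and production map, which this sketch does not model.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy vocabulary; index 0 is reserved for padding.
VOCAB = ["<pad>", "a", "b", "c", "(", ")"]
TOKEN_TO_IDX = {tok: i for i, tok in enumerate(VOCAB)}


def sentence_to_one_hot(sentence: str, max_len: int) -> torch.Tensor:
    """Encode a character-level 'sentence' as a (vocab_size, max_len) one-hot tensor."""
    idx = [TOKEN_TO_IDX[ch] for ch in sentence][:max_len]  # truncate if too long
    idx += [TOKEN_TO_IDX["<pad>"]] * (max_len - len(idx))  # right-pad with <pad>
    one_hot = F.one_hot(torch.tensor(idx), num_classes=len(VOCAB))  # (max_len, vocab_size)
    return one_hot.transpose(0, 1).float()  # (vocab_size, max_len), same layout as in the post


x = sentence_to_one_hot("a(b)c", max_len=8)
print(x.shape)  # torch.Size([6, 8])
```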

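from torch.utils.data import IterableDataset

class SentenceGenerator(IterableDataset):
    def __init__(self, grammar, min_sample_len, max_sample_len, batch_size=256, seed=0):
        super().__init__()
        self.grammar = grammar
        self.min_len = min_sample_len
        self.max_len = max_sample_len
        self.batch_sz = batch_size
        self.seed = seed
        # self.tokenizer and self.prod_map are also set up here (omitted)

    def generate_sentence(self):
        # generates a string of tokens ('sentence') using the rules encoded in self.grammar
        sent = []  # filled in by the grammar-sampling logic (omitted)
        return ''.join(sent)

    def generate_one_hots(self):
        # converts 'sentences' to sequences of one-hot vectors
        self.sents = [self.generate_sentence() for _ in range(self.batch_sz)]
        out = make_one_hot(self.grammar, self.tokenizer, self.prod_map, self.sents,
                           max_len=self.max_len)  # remaining arguments omitted
        return out.transpose(2, 1)  # (batch_size, vocab_size, max_length)

    def __iter__(self):
        return iter(self.generate_one_hots())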
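To make the behaviour of the `__iter__` above concrete, here is a small self-contained sketch that swaps the grammar for random tensors (`PrebatchedDataset` is a hypothetical stand-in, not my actual class). Because iterating over a tensor walks its first dimension, the DataLoader receives individual samples and re-collates them according to its own `batch_size`, which defaults to 1, and each pass over the dataset ends after a single pre-built batch.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class PrebatchedDataset(IterableDataset):
    """Hypothetical stand-in for SentenceGenerator: each __iter__ builds ONE batch and iterates over it."""

    def __init__(self, batch_size, vocab_size=5, max_len=7):
        super().__init__()
        self.batch_sz = batch_size
        self.vocab_size = vocab_size
        self.max_len = max_len

    def generate_one_hots(self):
        # Random placeholder data with the post's layout: (batch_size, vocab_size, max_length).
        return torch.rand(self.batch_sz, self.vocab_size, self.max_len)

    def __iter__(self):
        # Iterating over a tensor walks its first dimension, so this yields
        # batch_sz separate (vocab_size, max_len) samples and then stops.
        return iter(self.generate_one_hots())


ds = PrebatchedDataset(batch_size=4)
print([tuple(b.shape) for b in DataLoader(ds, batch_size=4)])  # [(4, 5, 7)]
print([tuple(b.shape) for b in DataLoader(ds, batch_size=2)])  # [(2, 5, 7), (2, 5, 7)]
print([tuple(b.shape) for b in DataLoader(ds)])  # [(1, 5, 7), (1, 5, 7), (1, 5, 7), (1, 5, 7)]
```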

  1. Batching: For some reason, generating a single matrix of one-hot vectors and letting the DataLoader batch them didn’t work: the batches were always of size one, and I also ran into other downstream issues with PyTorch Lightning. I therefore resorted to handling batching directly within SentenceGenerator, as you can see above. In the DataLoader, I then have to specify the same batch size as in the dataset for full batches to be produced. This is a bit hacky and causes further headaches downstream in understanding what an epoch is, when to step the LR scheduler, when to log a result, etc. Is there a more elegant way of achieving the same result?

  2. Multiple outputs from the dataset: To help the VAE train, I would like to pass it the sentence lengths explicitly. These cannot be derived directly from the matrix of one-hot vectors, because each token can take a variable number of timesteps to produce. Therefore, I need to return this value directly from the IterableDataset, like so:

    def generate_one_hots(self):
        ...  # same as above except for the `return` statement
        return out.transpose(2, 1), self.sent_lengths  # now returning 2 tensors

    def __iter__(self):
        return iter(self.generate_one_hots())

… and then unpack each batch in vae.forward() as x, n = batch, where n contains the lengths of the ‘sentences’ represented by the tensor of one-hot vectors x. However, with this implementation, only the first object returned by generate_one_hots() ends up in batch, and I get a “too few values to unpack” error. Do you have any suggestions? Ideally one that solves both issues at once!
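For comparison, this is a minimal sketch of the per-sample pattern I would expect to be the more idiomatic one, with random placeholder data instead of the grammar (`SampleGenerator`, `samples_per_epoch` and `generate_sample` are hypothetical names). Each `__iter__` call yields a fixed number of `(one_hot, length)` tuples, the DataLoader does the batching, and PyTorch's default collate function batches each tuple field separately, so `x, n = batch` unpacks as intended.

```python
import random

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset


class SampleGenerator(IterableDataset):
    """Hypothetical per-sample variant: yields (one_hot, length) pairs, one sentence at a time."""

    def __init__(self, samples_per_epoch, vocab_size=5, max_len=7, seed=0):
        super().__init__()
        self.samples_per_epoch = samples_per_epoch  # fixes what "one epoch" means
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.rng = random.Random(seed)

    def generate_sample(self):
        # Placeholder for generate_sentence() + one-hot conversion of a single sentence:
        # random token indices in 1..vocab_size-1, right-padded with index 0.
        length = self.rng.randint(1, self.max_len)
        idx = [self.rng.randrange(1, self.vocab_size) for _ in range(length)]
        idx += [0] * (self.max_len - length)
        x = F.one_hot(torch.tensor(idx), num_classes=self.vocab_size).T.float()  # (vocab_size, max_len)
        return x, length

    def __iter__(self):
        for _ in range(self.samples_per_epoch):
            yield self.generate_sample()  # a tuple; the default collate batches each field


loader = DataLoader(SampleGenerator(samples_per_epoch=10), batch_size=4)
for x, n in loader:  # each batch is [x_batch, lengths], so unpacking works
    print(tuple(x.shape), n.tolist())  # x shapes: (4, 5, 7), (4, 5, 7), then (2, 5, 7)
```

Two caveats I am aware of: with `num_workers > 0`, each worker iterates over its own copy of the dataset, so without splitting the work via `torch.utils.data.get_worker_info()` (and giving each worker a different seed) every worker would emit the same samples. Alternatively, if batching stays inside the dataset, passing `batch_size=None` to the DataLoader disables automatic batching, so a `(x_batch, lengths)` tuple yielded from `__iter__` would reach the model without being re-batched.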

Thanks a million in advance.