FlatMapDataset - transforming a map-style dataset into an iterable-style dataset

I sketched the concept of a flat-map operation for PyTorch datasets: FlatMapDataset. You provide it with a regular map-style dataset and a transformation function that yields a sequence. When iterated over, it fetches an element of the source Dataset and passes it to your transformation function. The results of all calls are continuously concatenated into a stream of examples (i.e. an IterableDataset), which can be passed to a DataLoader to form the training batches.
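To make the idea concrete, here is a minimal sketch of how such a class could look. It is an illustration under a few assumptions, not necessarily the actual implementation: the toy source dataset `ToyRanges` and the function `split_elements` are hypothetical names made up for this example, and the per-worker sharding via `get_worker_info()` is one possible design choice for avoiding duplicated examples when a DataLoader uses several workers.

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info


class FlatMapDataset(IterableDataset):
    """Illustrative sketch: flat-maps `fn` over a map-style `source` dataset."""

    def __init__(self, source, fn):
        self.source = source  # map-style dataset: needs __len__ and __getitem__
        self.fn = fn          # callable returning an iterable of examples

    def __iter__(self):
        info = get_worker_info()
        # Assumption: with several DataLoader workers, each worker takes every
        # n-th source index so that no source element is processed twice.
        start, step = (0, 1) if info is None else (info.id, info.num_workers)
        for index in range(start, len(self.source), step):
            yield from self.fn(self.source[index])


class ToyRanges(Dataset):
    """Hypothetical source dataset: item i is the tensor [0, 1, ..., i - 1]."""

    def __len__(self):
        return 4

    def __getitem__(self, index):
        return torch.arange(index)


def split_elements(example):
    # Yields one example per element of the source item; item 0 yields nothing.
    for value in example:
        yield value


flat = FlatMapDataset(ToyRanges(), split_elements)
flat_values = [int(v) for v in flat]                            # [0, 0, 1, 0, 1, 2]
batches = [b.tolist() for b in DataLoader(flat, batch_size=4)]  # [[0, 0, 1, 0], [1, 2]]
```

Note that a source item may yield zero examples, and that batches are filled across source-item boundaries, so the batch size is independent of how many examples each item produces.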

Here is an example:

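import torch
from torch.utils.data import DataLoader

# SourceDataset and crop_image are placeholders for your own code;
# FlatMapDataset is the class described above.
source_dataset = SourceDataset()
# The split lengths must add up to len(source_dataset).
source_train, source_valid = torch.utils.data.random_split(
        source_dataset, [9000, 1000])

def crop_boxes(example):
    # One source example (an image and its boxes) yields one crop per box.
    image, boxes = example
    for box in boxes:
        yield crop_image(image, box)

train_dataset = FlatMapDataset(source_train, crop_boxes)
data_loader = DataLoader(train_dataset)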

It is meant for situations where one file has to be split into an unknown number of actual training examples, e.g. when you'd like to extract multiple crops of a single training image to train an autoencoder / GAN. Running a large pre-processing step up front is not always practical in such a setting.

