How to handle Pytorch Dataset with transform function that returns >1 output per row of data?

Given a myfile.csv file that looks like:

imagefile,label
train/0/16585.png,0
train/0/56789.png,0

The goal is to create a PyTorch DataLoader that, when looped over, returns 2x the data points, e.g.:

>>> dp = MyDataPipe(csvfile)
>>> for row in dp.train_dataloader():
...     print(row)
...
(tensor([1.23, 4.56, 7.89]), 0)
(tensor([9.87, 6.54, 3.21]), 1)
(tensor([9.99, 8.88, 7.77]), 0)
(tensor([1.11, 2.22, 9.87]), 1)

I’ve tried writing the dataloader for the case where we expect the same no. of rows from the dataloader as in the input file, and this works:

import torch

from typing import Callable

from torch.utils.data import DataLoader2
from torchdata.datapipes.iter import IterDataPipe, IterableWrapper
import pytorch_lightning as pl


content = """imagefile,label
train/0/16585.png,0
train/0/56789.png,0"""

with open('myfile.csv', 'w') as fout:
    fout.write(content)


def optimus_prime(row):
    """This function returns one data point with some arbitrary vector.
    >>> row = {'imagefile': 'train/0/16585.png', 'label': 0}
    >>> optimus_prime(row)
    (tensor([1.23, 4.56, 7.89]), 0)
    """
    # We are using torch.rand here but there is an actual function
    # that converts the png file into a vector.
    vector1 = torch.rand(3)
    return vector1, row['label']
    

class MyDataPipe(pl.LightningDataModule):
    def __init__(
        self,
        csv_files: list[str],
        skip_lines: int = 0,
        transform_func: Callable = None
    ):
        super().__init__()
        self.csv_files: list[str] = csv_files
        self.skip_lines: int = skip_lines

        # Initialize a datapipe.
        self.dp_chained_datapipe: IterDataPipe = (
            IterableWrapper(iterable=self.csv_files)
            .open_files()
            .parse_csv_as_dict(skip_lines=self.skip_lines)
        )
            
        if transform_func:
            self.dp_chained_datapipe = self.dp_chained_datapipe.map(transform_func)

    def train_dataloader(self, batch_size=1) -> DataLoader2:
        return DataLoader2(dataset=self.dp_chained_datapipe, batch_size=batch_size)

dp = MyDataPipe(['myfile.csv'], transform_func=optimus_prime)

for row in dp.train_dataloader():
    print(row)

If the optimus_prime function returns 2 data points, how do I set up the DataLoader such that it collates the 2 data points accordingly?

How do I formulate the collate function, or tell the DataLoader that there are 2 outputs in each .map(transform_func) result? E.g. if I change the function to:

def optimus_prime(row):
    """This function returns two data points with some arbitrary vectors.
    >>> row = {'imagefile': 'train/0/16585.png', 'label': 0}
    >>> optimus_prime(row)
    (tensor([1.23, 4.56, 7.89]), 0), (tensor([3.21, 6.54, 9.87]), 1)
    """
    # We are using torch.rand here but there is an actual function
    # that converts the png file into a vector.
    vector1 = torch.rand(3)
    vector2 = torch.rand(3)
    yield vector1, row['label']
    yield vector2, not row['label']

I’ve also tried the following, which requires running the optimus_prime logic twice, but the 2nd .map(transform_func) throws a TypeError: tuple indices must be integers or slices, not str


def optimus_prime_1(row):
    # We are using torch.rand here but there is an actual function
    # that converts the png file into a vector.
    vector1 = torch.rand(3) 
    yield vector1, row['label']

def optimus_prime_2(row):
    # We are using torch.rand here but there is an actual function
    # that converts the png file into a vector.
    vector2 = torch.rand(3) 
    yield vector2, not row['label']
    

class MyDataPipe(pl.LightningDataModule):
    def __init__(
        self,
        csv_files: list[str],
        skip_lines: int = 0,
        transform_funcs: list[Callable] = None
    ):
        super().__init__()
        self.csv_files: list[str] = csv_files
        self.skip_lines: int = skip_lines

        # Initialize a datapipe.
        self.dp_chained_datapipe: IterDataPipe = (
            IterableWrapper(iterable=self.csv_files)
            .open_files()
            .parse_csv_as_dict(skip_lines=self.skip_lines)
        )
            
        if transform_funcs:
            for transform_func in transform_funcs:
                self.dp_chained_datapipe = self.dp_chained_datapipe.map(transform_func)

    def train_dataloader(self, batch_size=1) -> DataLoader2:
        return DataLoader2(dataset=self.dp_chained_datapipe, batch_size=batch_size)

dp = MyDataPipe(['myfile.csv'], transform_funcs=[optimus_prime_1, optimus_prime_2])

for row in dp.train_dataloader():
    print(row)

BTW, this is also asked on Stack Overflow: How to handle Pytorch Dataset with transform function that returns >1 output per row of data?

If you want to create 2x the data points, I recommend using .flatmap(fn) instead of .map(fn) after .parse_csv_as_dict(). Then you can shuffle and batch as needed afterwards.

This is assuming you don’t require those two data points grouped together in the same batch.
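
For example, here is a minimal sketch reusing the myfile.csv and optimus_prime idea from above. I return a list instead of yielding, and cast the label myself since parse_csv_as_dict yields strings; torch.rand still stands in for the real png-to-vector function:

import torch
from torchdata.datapipes.iter import IterableWrapper

def optimus_prime(row):
    # parse_csv_as_dict yields strings, so cast the label first.
    label = int(row['label'])
    # Return an iterable of data points; .flatmap will flatten it.
    return [(torch.rand(3), label), (torch.rand(3), not label)]

dp = (
    IterableWrapper(['myfile.csv'])
    .open_files()
    .parse_csv_as_dict()  # the header row becomes the dict keys
    .flatmap(optimus_prime)
)

for row in dp:
    print(row)  # two (tensor, label) pairs per CSV line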

Thanks for the suggestion. I’m looking at FlatMapper in the TorchData documentation and I’m not exactly sure how it differs from Mapper.

But I’ve managed to get it to work with .flatmap() (see the Stack Overflow question linked above).


It looks like it’s doing some sort of lambda x: chain(*map(transform_func, x)) instead of lambda x: map(transform_func, x), but I’m not exactly sure I’m understanding it correctly.
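
If I approximate both in plain Python, ignoring the datapipe machinery (my own sketch, not the actual TorchData internals):

from itertools import chain

def duplicate(x):
    return [x, x]

data = [1, 2, 3]

# .map(fn): one output element per input element
print(list(map(duplicate, data)))                       # [[1, 1], [2, 2], [3, 3]]

# .flatmap(fn): map, then flatten one level
print(list(chain.from_iterable(map(duplicate, data))))  # [1, 1, 2, 2, 3, 3]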

Would FlatMapper’s .flatmap() read through the whole dataset first? Or would it process each data point one-by-one on the fly, like .map()?

flatmap reads one data point at a time, applies the given transformation, then flattens the result of the transformation. map does the same but without the flattening.

A common use case is to take in a data point, generate variations of the same input, and return those variations in a list:

def create_two_versions(x: str):
    return [x.lower(), x.upper()]

dp = IterableWrapper(["Apple", "bAnAnA"])
flatmap_dp = dp.flatmap(create_two_versions)
print(list(flatmap_dp))  # ['apple', 'APPLE', 'banana', 'BANANA']
# Because of the additional flatten operation, each list is flattened into separate data points.
# You get 4 data points from 2 original data points.
# You can perform `.shuffle`, `.batch`, or any other transformation after this.

map_dp = dp.map(create_two_versions)
print(list(map_dp))  # [['apple', 'APPLE'], ['banana', 'BANANA']]
# The lack of flattening means each data point is transformed into a list.
# You still only have two data points.
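
And because the flattening happens per data point, you can chain further transforms straight onto flatmap_dp, for instance (the batch size here is arbitrary, and the order will vary because of the shuffle):

batched_dp = flatmap_dp.shuffle().batch(2)
print(list(batched_dp))  # e.g. [['banana', 'APPLE'], ['apple', 'BANANA']]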