I’ve been toying with the new TorchData library, and I came up with an implementation that seems overly complicated for something pretty simple, so I wonder if I missed a simpler way of doing things.
Let’s say I’m doing semantic segmentation and I have a dataset folder containing two subfolders, 'images' and 'masks', with identically named PNG files. I want to load each (image, mask) couple as PIL Images and augment the couple. That would be the old `__getitem__()` method. Then I want to split into train and validation sets, shuffle, and batch. That would be the old DataLoader.
Currently I’m doing this:
```python
from PIL import Image
from torch.utils.data import default_collate
from torchdata.datapipes.iter import FileLister, Zipper

image_files = FileLister([rootdir / 'images'], masks='*.png')
mask_files = FileLister([rootdir / 'masks'], masks='*_multiclass.png')
images = image_files.map(Image.open)
masks = mask_files.map(Image.open)
dp = Zipper(images, masks).map(transform)

def train_val_split(x: dict) -> bool:
    return x['index'] >= split * length

train_dp, val_dp = dp.add_index().demux(num_instances=2, classifier_fn=train_val_split)
train_dp = train_dp.shuffle().batch(batch_size).collate(default_collate)
val_dp = val_dp.shuffle().batch(batch_size).collate(default_collate)
```
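For context on the `demux` step: it routes each element to the output pipe whose number `classifier_fn` returns, and since `bool` subclasses `int` in Python, `False` selects pipe 0 (train) and `True` selects pipe 1 (validation). A minimal torchdata-free sketch of that routing logic, with the assumed values `split = 0.8` and `length = 10`:

```python
split, length = 0.8, 10  # assumed values for illustration

def train_val_split(x: dict) -> bool:
    # False (== 0) routes the element to the train pipe,
    # True (== 1) routes it to the validation pipe.
    return x['index'] >= split * length

samples = [{'index': i} for i in range(length)]
train = [s for s in samples if not train_val_split(s)]
val = [s for s in samples if train_val_split(s)]
# first 80% of indices end up in train, the remaining 20% in val
```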
Here, split is a float in [0, 1], length is the number of (image, mask) couples, and transform is a function returning a dict {'image': tensor[3, H, W], 'mask': tensor[H, W]}.
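For concreteness, here is a hypothetical sketch of what such a transform could look like (augmentations omitted); the function name and tensor conventions match the description above, everything else is an assumption:

```python
import numpy as np
import torch
from PIL import Image

def transform(pair):
    # Hypothetical example: turn a (PIL image, PIL mask) couple into
    # the dict of tensors described above.
    image, mask = pair
    # RGB image: HWC uint8 -> CHW float32 in [0, 1]
    image_t = torch.from_numpy(np.array(image, dtype=np.float32) / 255.0).permute(2, 0, 1)
    # Mask: single-channel class indices, kept as an integer tensor [H, W]
    mask_t = torch.from_numpy(np.array(mask, dtype=np.int64))
    return {'image': image_t, 'mask': mask_t}
```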
So, is this indeed too complicated, and is there a cleaner way of doing things?