Hi,
I've been toying with the new TorchData library, and I came up with an implementation that seems overly complicated for something pretty simple, so I wonder if I missed a simpler way of doing things.
Let's say I'm doing semantic segmentation, and I have a dataset folder containing two subfolders, 'images' and 'masks', with identically named png files. I want to load each (image, mask) pair as PIL Images and augment the pair; that would be the old `__getitem__()` method. Then I want to split into train and validation sets, shuffle, and batch; that would be the old DataLoader.
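For context, here is a rough sketch of the old-style map dataset I mean (the class name and the dict layout of a sample are my own choices; any class exposing `__len__`/`__getitem__` works with `torch.utils.data.DataLoader`, which then handles the shuffling and batching):

```python
from pathlib import Path
from PIL import Image

class SegmentationDataset:
    """Classic map-style dataset pairing images and masks by sorted filename."""

    def __init__(self, rootdir, transform=None):
        rootdir = Path(rootdir)
        # Pair files by sorted order; assumes matching names in both folders.
        self.image_files = sorted((rootdir / "images").glob("*.png"))
        self.mask_files = sorted((rootdir / "masks").glob("*.png"))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        sample = {
            "image": Image.open(self.image_files[idx]),
            "mask": Image.open(self.mask_files[idx]),
        }
        return self.transform(sample) if self.transform else sample
```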
Currently I’m doing this:
from pathlib import Path
from PIL import Image
from torch.utils.data import default_collate
from torchdata.datapipes.iter import FileLister, Zipper

image_files = FileLister([rootdir / 'images'], masks='*.png')
mask_files = FileLister([rootdir / 'masks'], masks='*_multiclass.png')
images = image_files.map(Image.open)
masks = mask_files.map(Image.open)
dp = Zipper(images, masks).map(transform)

def train_val_split(x: dict) -> int:
    return x['index'] >= split * length

train_dp, val_dp = dp.add_index().demux(num_instances=2, classifier_fn=train_val_split)
train_dp = train_dp.shuffle().batch(batch_size).collate(default_collate)
val_dp = val_dp.shuffle().batch(batch_size).collate(default_collate)
where `split` is a float in [0, 1], `length` is the number of (image, mask) pairs, `transform` is a function returning a dict {'image': tensor[3, H, W], 'mask': tensor[H, W]}, and `default_collate` is `torch.utils.data.default_collate`.
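For completeness, a minimal `transform` matching that description could look like the following (the conversion details and dtypes are my assumptions; the actual augmentations are omitted):

```python
import numpy as np
import torch

def transform(pair):
    """Convert a (PIL image, PIL mask) pair into the dict described above."""
    image, mask = pair
    # HWC uint8 image -> CHW float tensor in [0, 1]
    img = torch.as_tensor(np.array(image, dtype=np.float32) / 255.0).permute(2, 0, 1)
    # single-channel mask -> [H, W] integer class tensor
    msk = torch.as_tensor(np.array(mask, dtype=np.int64))
    return {"image": img, "mask": msk}
```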
So, is this indeed too complicated, and is there a cleaner way of doing things?
Thanks !