Why are transforms given to the dataset and not to the dataloader?

Hey team,

As a PyTorch novice, I’m deeply impressed by how clearly the different elements of a deep learning pipeline (e.g. dataset, transforms, data loader) are standardized and separated. However, I don’t quite understand why the transforms are specified when creating a dataset, as opposed to being passed as a parameter to the data loader that follows.

Right now I have to create a transformation pipeline before instantiating the dataset itself. Looking at my notebook, this seems a little like doing things in reverse. Wouldn’t it feel more natural to first specify the dataset and then follow with the transformations I want to apply to it?

This would also allow an easier workflow with images: dataset[i] would return the i-th PIL image, which Jupyter renders as an image by default, as opposed to a torch tensor, which doesn’t do much for me until training starts.

Specifying the transforms in the data loader would also allow a cleaner workflow for training/validation splits, where I want to run different transformations on each split. Right now I have to create the same dataset twice with different transformations, then create two subset samplers and pass them to the two data loaders. If transforms were given to the loader, there would be only one dataset object, used to create two subsets that each get their own training or validation transformations.
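For reference, the current workflow described above might look roughly like this. ToyDataset and the two transform functions are hypothetical stand-ins for a real image dataset and its transformation pipelines; DataLoader and SubsetRandomSampler are the actual torch.utils.data classes.

```python
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

class ToyDataset(Dataset):
    """Hypothetical stand-in for an image dataset: sample i is the scalar tensor i."""
    def __init__(self, n, transform=None):
        self.data = torch.arange(n, dtype=torch.float32)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)
        return x

def train_transform(x):
    # Hypothetical "augmentation", used only for training.
    return x + 0.5

def val_transform(x):
    # Hypothetical deterministic transform for validation.
    return x

# Split the indices once so the two loaders see disjoint samples.
indices = torch.randperm(10).tolist()
train_idx, val_idx = indices[:8], indices[8:]

# The same data is wrapped twice, differing only in the transform...
train_ds = ToyDataset(10, transform=train_transform)
val_ds = ToyDataset(10, transform=val_transform)

# ...and each loader gets a sampler restricted to its own split.
train_loader = DataLoader(train_ds, batch_size=4, sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(val_ds, batch_size=4, sampler=SubsetRandomSampler(val_idx))
```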

Does this make sense?

Although my post sounds like I’m suggesting a change, I’m surely just overlooking something in plain sight. I would like to know what that is.



I see the advantages of your approach.
However, I think one disadvantage would be that all transformations would have to be applied at the end of __getitem__ (or after it).
For a lot of use cases this might be a valid approach. However, if you would like to apply some transformations after loading your data, then do something fancy like a sliding-window approach, and finally apply another transformation, this workflow might break, and I’m not sure the change would be worth it.
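To make that ordering concrete, here is a minimal sketch of a __getitem__ that transforms the whole recording, cuts a sliding window out of it, and then transforms the window. SlidingWindowDataset, normalize and add_channel_dim are hypothetical names chosen for illustration. The point is that the crop happens between two transformation stages, while a single transform applied to the output of __getitem__ would only ever see the cropped window.

```python
import torch
from torch.utils.data import Dataset

class SlidingWindowDataset(Dataset):
    """Hypothetical example: transform the full signal, cut a window, transform the window."""
    def __init__(self, signal, window, stride, pre_transform=None, post_transform=None):
        self.signal = signal
        self.window = window
        self.stride = stride
        self.pre_transform = pre_transform
        self.post_transform = post_transform

    def __len__(self):
        # Number of complete windows that fit into the signal.
        return (len(self.signal) - self.window) // self.stride + 1

    def __getitem__(self, index):
        x = self.signal
        if self.pre_transform is not None:
            x = self.pre_transform(x)  # stage 1: needs the whole signal (e.g. global normalization)
        start = index * self.stride
        x = x[start:start + self.window]  # stage 2: sliding-window crop
        if self.post_transform is not None:
            x = self.post_transform(x)  # stage 3: per-window transform
        return x

def normalize(x):
    # Zero mean, unit standard deviation over the entire signal.
    return (x - x.mean()) / x.std()

def add_channel_dim(x):
    # Shape (window,) -> (1, window), e.g. for a Conv1d input.
    return x.unsqueeze(0)

ds = SlidingWindowDataset(torch.arange(10, dtype=torch.float32), window=4, stride=2,
                          pre_transform=normalize, post_transform=add_channel_dim)
print(len(ds))      # 4
print(ds[0].shape)  # torch.Size([1, 4])
```

Normalizing each window on its own would give different values than normalizing the full signal first, which is why the first stage cannot simply be moved after the crop.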

An easy workaround for your use case would be to write a get_sample(self, index) method that only loads a single sample, and call it in __getitem__.
Your Dataset would look like this:

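```python
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def get_sample(self, index):
        # Load a single raw sample without any transformation.
        img = Image.open(self.paths[index])
        return img

    def __getitem__(self, index):
        x = self.get_sample(index)
        # Apply the transformation pipeline, if one was passed in.
        if self.transform is not None:
            x = self.transform(x)
        return x

    def __len__(self):
        return len(self.paths)
```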

Using this approach, you could call dataset.get_sample(0) to get the plain, untransformed sample whenever you would like to look at the data in your notebook.


The get_sample() function is actually a cool workaround. I’ll try that coupled with a transforms argument for a custom data loader.

Thanks for your time!