Simple, efficient way to create Dataset?

After countless searches and piecing the puzzle together, I came up with this code for a “boilerplate” custom image Dataset class. Is this reasonably simple and efficient, or does anyone see where I might run into trouble?

The use case is to quickly and simply bring in whatever images are stored in a designated folder and have them ready for input, without a lot of fluff.

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from pathlib import Path  # standard library on Python 3 (pathlib2 is only a backport)

class MyDataClass(Dataset):
    def __init__(self, image_path, transform=None):
        super().__init__()
        # Build (image, label) samples from the class subfolders of image_path
        self.data = datasets.ImageFolder(str(image_path), transform=transform)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return x, y

    def __len__(self):
        return len(self.data)

root = Path.cwd()        # Map to dataset folder location (here: the current working directory)
img_path = root/'dataset_folder'

tsfm = transforms.Compose([
        # Misc transforms here.
        transforms.ToTensor()
])

data = MyDataClass(img_path, tsfm) # create the dataset

dataloader = DataLoader(data, batch_size=20)

Aside from any extensive customization of my labels, are there any hangups I should be aware of, or could I just use this as boilerplate to prepare any generic PNG image dataset for input to a Conv2d layer?

@ptrblck I set this up and it works fine. I just wanted to know whether I’m missing anything or doing anything wrong, since a simple solution like this doesn’t seem to appear in the tutorials or the documentation.

Your custom Dataset should work fine. However, as it stands you could also use ImageFolder directly, without wrapping it in a custom Dataset, and pass it to a DataLoader.
If you want to customize the samples inside __getitem__, your approach looks good!


Awesome, thanks. So for reference, one can simply write:

data = datasets.ImageFolder(img_path, tsfm) 

and pass it straight into a DataLoader. That’s by far the simplest, no-fuss way to create a dataset in PyTorch. For anyone finding this thread: ImageFolder returns PIL images by default, so you need to convert them to tensors with the following transform before batching them in a DataLoader or feeding them to most nn.Modules.

transforms.ToTensor()

Yes, that’s correct. Implementing a custom Dataset is useful if you need specific indexing, processing, etc. If your images are stored in one subfolder per class, you can simply use ImageFolder.
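The layout ImageFolder expects is one subdirectory per class, with class indices assigned in sorted order of the subdirectory names. The pure-Python helper below (`list_samples` and its short extension list are hypothetical) approximates that mapping so it can be inspected without torch; the real ImageFolder accepts more extensions and exposes the same information as `classes`, `class_to_idx`, and `samples`:

```python
import os
import tempfile
from pathlib import Path

IMG_EXTS = (".png", ".jpg", ".jpeg")  # assumption: a small subset of image extensions


def list_samples(root):
    """Hypothetical helper approximating how ImageFolder maps a
    one-folder-per-class layout to (path, class_index) pairs."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        for dirpath, _, filenames in sorted(os.walk(Path(root) / name)):
            for fname in sorted(filenames):
                if fname.lower().endswith(IMG_EXTS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, samples


# Illustrative layout (empty placeholder files; the .txt file is skipped):
#   root/dog/001.png
#   root/cat/xyz.png
#   root/cat/notes.txt
root = Path(tempfile.mkdtemp())
for rel in ("dog/001.png", "cat/xyz.png", "cat/notes.txt"):
    (root / rel).parent.mkdir(parents=True, exist_ok=True)
    (root / rel).touch()

classes, class_to_idx, samples = list_samples(root)
print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
print([(Path(p).name, y) for p, y in samples])  # [('xyz.png', 0), ('001.png', 1)]
```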