After countless searches and piecing things together, I came up with this code for a “boilerplate” custom image Dataset class. Is it reasonably simple and efficient, or does anyone see where I might run into trouble?
The use case is to quickly and simply bring in whatever images I have stored in a designated folder and have them ready for input, without a lot of fluff.
    from pathlib import Path

    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms, datasets


    class MyDataClass(Dataset):
        def __init__(self, image_path, transform=None):
            super().__init__()
            # Create data from folder (ImageFolder expects one subfolder per class)
            self.data = datasets.ImageFolder(image_path, transform)

        def __getitem__(self, idx):
            x, y = self.data[idx]
            return x, y

        def __len__(self):
            return len(self.data)


    # Map to dataset folder location (relative to the current working directory)
    root = Path.cwd()
    img_path = root / 'dataset_folder'

    tsfm = transforms.Compose([
        # Misc transforms here.
        transforms.ToTensor()
    ])

    data = MyDataClass(img_path, tsfm)  # create the dataset
    dataloader = DataLoader(data, batch_size=20)
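One thing worth checking before pointing this at a new folder: `ImageFolder` takes its labels from the subfolder names, sorted alphabetically and numbered from 0, so it needs at least one class subfolder under the root. Here is a rough sketch of that mapping using only the standard library. The helper name `find_class_folders` is made up for illustration and is not part of torchvision.

```python
from pathlib import Path


def find_class_folders(root):
    """Return a {class_name: index} mapping built from the subfolders of root.

    Hypothetical helper: mirrors the sorted-subfolder labelling that
    ImageFolder uses, so the labels can be inspected before loading.
    """
    # Only directories count as classes; loose files in root are ignored.
    class_names = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    if not class_names:
        raise FileNotFoundError(f"No class subfolders found in {root}")
    return {name: idx for idx, name in enumerate(class_names)}
```

Calling `find_class_folders(img_path)` before building the dataset should show which integer label each folder will get, and fail early if the images were dropped straight into the root with no subfolders.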
Aside from extensive customization of my labels, are there any hangups you can see here that I should be aware of, or can I use this as my boilerplate to prepare any generic PNG image dataset for input to a Conv2d layer?
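One hangup I can think of with "any generic PNG dataset" is mixed image sizes: the default `DataLoader` collation stacks the tensors in a batch, so with `batch_size=20` every image has to come out of the transforms with the same shape (for example via a resize or crop step). Below is a small standard-library sketch for surveying the sizes in a folder by reading each PNG header without decoding the image. The names `png_size` and `count_png_sizes` are invented for this example.

```python
import struct
from collections import Counter
from pathlib import Path

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(path):
    """Read (width, height) from a PNG file's IHDR chunk (hypothetical helper)."""
    with open(path, "rb") as f:
        header = f.read(24)
    # Bytes 0-7: signature, 8-11: chunk length, 12-15: b"IHDR",
    # 16-19: width, 20-23: height (both big-endian unsigned ints).
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} does not look like a PNG file")
    width, height = struct.unpack(">II", header[16:24])
    return width, height


def count_png_sizes(root):
    """Count how many .png files under root have each (width, height)."""
    return Counter(png_size(p) for p in Path(root).rglob("*.png"))
```

If `count_png_sizes(img_path)` reports more than one size, a resize or crop would presumably need to go into the `Compose` list ahead of `ToTensor()`.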