After countless searches and a lot of piecing things together, I came up with this code for a “boilerplate” custom image Dataset class. Is it reasonably simple and efficient, or does anyone see where I might run into trouble?
The use case is to quickly and simply bring in whatever images I have stored in a designated folder and have them ready for input, without a lot of fluff.
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets


class MyDataClass(Dataset):
    def __init__(self, image_path, transform=None):
        super().__init__()
        # Create data from folder; ImageFolder expects one subfolder per class
        self.data = datasets.ImageFolder(image_path, transform)

    def __getitem__(self, idx):
        # Returns (image, class_index) for the sample at position idx
        x, y = self.data[idx]
        return x, y

    def __len__(self):
        return len(self.data)


root = Path.cwd()  # Map to dataset folder location (here: the current working directory)
img_path = root / 'dataset_folder'

tsfm = transforms.Compose([
    # Misc transforms here.
    transforms.ToTensor()
])

data = MyDataClass(img_path, tsfm)  # create the dataset
dataloader = DataLoader(data, batch_size=20)
Aside from extensive customization of my labels, are there any hangups you can see here that I should be aware of, or could I just use this as my boilerplate to prepare any generic PNG image dataset for loading into a Conv2d layer?
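
One hangup I can already guess at: as far as I understand, ImageFolder wants one subfolder per class (dataset_folder/class_name/image.png), so images sitting loose in the top-level folder would not be picked up. Below is a minimal sketch for sanity-checking the layout before building the dataset, using only the standard library. The helper names count_images_per_class and find_loose_images are my own for illustration, not part of torchvision, and the check only looks at the file suffix.

```python
from pathlib import Path


def count_images_per_class(root, suffix=".png"):
    """Hypothetical helper: map each class subfolder of `root` to its image count."""
    root = Path(root)
    counts = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        # Search nested folders too and compare the suffix case-insensitively
        counts[class_dir.name] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() == suffix
        )
    return counts


def find_loose_images(root, suffix=".png"):
    """Hypothetical helper: images placed directly in `root`, outside any class subfolder."""
    return sorted(
        f.name
        for f in Path(root).iterdir()
        if f.is_file() and f.suffix.lower() == suffix
    )
```

Calling count_images_per_class(img_path) before creating MyDataClass should show right away whether a class folder is empty, and find_loose_images(img_path) lists files that would need to be moved into a class subfolder first.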