I am new to PyTorch and am attempting to load a custom dataset composed only of images. I'm not sure where I'm going wrong. Could someone please help me understand? Thanks in advance!
#!/usr/bin/env python
import os, glob
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    def __init__(self, images_dir):
        self.images_dir = images_dir

    def __len__(self):
        num = len(glob.glob(os.path.join(self.images_dir, '*.png')))
        return num

    def __getitem__(self, index):
        image = read_image(self.images_dir[index])
        return image
The __getitem__ logic looks wrong. images_dir is a string holding the path to the directory that contains all images, so self.images_dir[index] returns a single character of that path rather than the path to an image file. Instead, create a list of paths to all images and index that list in __getitem__.
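To make the problem concrete, here is a minimal illustration of what indexing a string does. The directory name `data/images` is a made-up placeholder, not something from your post:

```python
# Hypothetical directory path, used only for illustration.
images_dir = "data/images"

# Indexing a string yields a single character, not a file path.
first = images_dir[0]
print(first)  # d
```

So read_image would be handed a one-character string such as "d" as the file path, which is not a valid image file.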
Build the list of paths once in __init__. This avoids the overhead of calling glob.glob(os.path.join(self.images_dir, '*.png')) in every __getitem__ call and again in __len__.
Duplicating the glob call can lead to subtle errors, and it also costs time on every sample, because the full list of image paths would be rebuilt each time an item is loaded.