I have some images organized in folders as shown in the following picture:
In order to create a PyTorch DataLoader I defined a custom Dataset in this way
class CustomDataset(Dataset):
def __init__(self, root, dirs=None, transforms=None):
self.root_dir = root
self.sel_dirs = dirs
self.transforms = transforms
def __getFileNr(self, path):
count = 0
for f in os.listdir(path):
absPathIn = os.path.join(path, f)
if os.path.isdir(absPathIn):
count += self.__getFileNr(absPathIn)
elif os.path.isfile(absPathIn):
count += 1
return count
def __len__(self):
if self.sel_dirs == None:
return
total_count = 0
for dir in self.sel_dirs:
path = os.path.join(self.root_dir, dir)
if os.path.isfile(path):
total_count += 1
elif os.path.isdir(path):
count = 0
count += self.__getFileNr(path)
total_count += count
return total_count
def __checkDir(self, dirPath):
images = []
labels = []
for f in os.listdir(dirPath):
absDirPath = os.path.join(dirPath, f)
if os.path.isdir(absDirPath):
tmpImgs, tmpLabels = self.__checkDir(absDirPath)
images.extend(tmpImgs)
labels.extend(tmpLabels)
elif os.path.isfile(absDirPath):
images.append(absDirPath)
fileName = os.path.splitext(os.path.split(absDirPath)[1])[0]
labels.append(fileName.split('_')[-1])
return (images, labels)
def __getitem__(self):
images = []
labels = []
for dir in self.sel_dirs:
dirPath = os.path.join(self.root_dir, dir)
if os.path.isdir(dirPath):
tmpImgs, tmpLabels = self.__checkDir(dirPath)
images.extend(tmpImgs)
labels.extend(tmpLabels)
else:
return
if self.transforms and len(images) > 0:
images = self.transforms(images)
return images, labels
Everything seems to work, I’m new to PyTorch and I’m not sure this is a good solution, however when I try to create the DataLoader, I obtain the following error message:
I defined the function to create the custom Dataset as follow:
def get_dataset(path, dirs):
transforms = T.Compose([
T.ToPILImage(),
T.Resize((224,224)),
T.RandomHorizontalFlip(),
T.ToTensor()
])
dataset = CustomDataset(root=path, dirs=dirs, transforms=transforms)
return dataset
The name of every image is format filename_label.jpg
.
I understand the error, but I don’t know how to fix it.
Can you help me please?