I'm new to PyTorch. I have a custom dataset that reads images from disk; it looks like this:
class SaffronDataset(Dataset):
    """Paired image / segmentation-label dataset loaded lazily from disk.

    Image paths are listed from ``<base_dir>/raw`` and label paths from
    ``<base_dir>/segmentation``. Both listings are sorted, and sample ``i``
    pairs the i-th image path with the i-th label path, so the two
    directories must contain file names that sort into matching order.
    """

    def __init__(self, base_dir, transform=None, target_transform=None):
        """Collect the image and label file paths found under *base_dir*.

        Args:
            base_dir: Directory containing the ``raw`` and ``segmentation``
                subdirectories.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each loaded label.
        """
        self.images_dir = os.path.join(base_dir, 'raw')
        self.labels_dir = os.path.join(base_dir, 'segmentation')
        self.transform = transform
        self.target_transform = target_transform
        # Sorting both listings is what keeps image i aligned with label i.
        self.images = sorted(
            os.path.join(self.images_dir, name)
            for name in os.listdir(self.images_dir)
        )
        self.labels = sorted(
            os.path.join(self.labels_dir, name)
            for name in os.listdir(self.labels_dir)
        )

    def __len__(self):
        """Return the number of image files found in the ``raw`` directory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and return the ``(image, label)`` pair at position *idx*.

        Args:
            idx: A single integer position. It indexes plain Python lists,
                so passing a list of indices raises ``TypeError``.

        Returns:
            A tuple ``(image, label)``, each passed through its own
            transform when one was supplied.
        """
        image = Image.open(self.images[idx])
        label = Image.open(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            # Labels must go through target_transform, not the image
            # transform (the original applied self.transform here).
            label = self.target_transform(label)
        return image, label
I have a basic DataLoader too:
# Iterate over `data` in shuffled batches of two samples.
data_loader = DataLoader(dataset=data, shuffle=True, batch_size=2)
When I try to iterate over the DataLoader, I get the error `TypeError: list indices must be integers or slices, not list`. The problem is that the `__getitem__()` function is receiving a list instead of a single index, and I don't understand why. Can anyone help me with that?