I have 4 folders which go like this
Train, Train_masks
Val, Val_masks
i am trying to build a dataloader to read this data so that i can feed it to a network,by looking a a few examples i have come up with this,
class DSB(Dataset):
def __init__(self, root, subset = 'train', transform = None):
self.root = os.path.expanduser(root)
self.transform = transform
self.subset = subset
self.data_path, self.label_path =[], []
def load_images(path):
images_dir = [os.path.join(path, file) for file in os.listdir(path) if os.path.isfile(os.path.join(path,file))]
images_dir.sort()
return images_dir
if self.subset =='train':
self.datapath = load_images(self.root+'train')
self.label_path = load_images(self.root+'train_label')
elif self.subset == 'val':
self.datapath = load_images(self.root+'val')
self.label_path = load_images(self.root+'val_label')
else:
raise RuntimeError('Invalid Dataset'+ self.subset + ', it must be one of:'
' \'train\', \'val\'')
def __getitem__(self,index):
img = Image.open(self.data_path[index])
target = Image.open(self.label_path[index]) if not self.subset == 'test' else None
if self.transform is not None:
img = self.transform(img)
target = self.transform(target)
return img, target
def __len__(self):
return len(self.data_path)
def im_show(img_list):
to_PIL = transforms.ToPILImage()
if len(img_list) > 9:
raise Exception("len(img_list) must be smaller than 10")
for idx, img in enumerate(img_list):
img = np.array(to_PIL(img))
plt.subplot(100 + 10 * len(img_list) + (idx + 1))
fig = plt.imshow(img)
fig.axes.get_xaxis().set_visible(False)
fig.axes.get_yaxis().set_visible(False)
plt.show()
`
but when i run this ,
train_dataset = DSB(root='/media/ryan/da5df9e4-cdc6-4d55-91e8-b2383e89165f/dsbdata/' ,
subset="train",
transform=transforms.Compose([
transforms.Scale((256, 256)),
transforms.ToTensor()])
)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=8,
shuffle=True,
pin_memory=True,
num_workers=1)
img_list = []
for i in range(4):
img, label = train_dataset[i]
img_list.append(img)
img_list.append(label)
im_show(img_list)
i get
NotImplementedError Traceback (most recent call last)
<ipython-input-286-d95b0e86d99d> in <module>()
16 img_list = []
17 for i in range(4):
---> 18 img, label = train_dataset[i]
19 img_list.append(img)
20 img_list.append(label)
/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataset.py in __getitem__(self, index)
11
12 def __getitem__(self, index):
---> 13 raise NotImplementedError
14
15 def __len__(self):
NotImplementedError:
Any suggestions on what i should try to change or any feedback is highly appreciated,
Thanks in advance