Trying to build a dataloader

I have 4 folders, laid out like this:
Train, Train_masks

Val, Val_masks

I am trying to build a dataloader to read this data so that I can feed it to a network. By looking at a few examples, I have come up with this:

class DSB(Dataset):
    def __init__(self, root, subset = 'train', transform = None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.subset = subset
        self.data_path, self.label_path =[], []
        def load_images(path):
            images_dir = [os.path.join(path, file) for file in os.listdir(path) if os.path.isfile(os.path.join(path,file))]
            return images_dir
        if self.subset =='train':
            self.datapath = load_images(self.root+'train')
            self.label_path = load_images(self.root+'train_label')
        elif self.subset == 'val':
            self.datapath = load_images(self.root+'val')
            self.label_path = load_images(self.root+'val_label')
            raise RuntimeError('Invalid Dataset'+ self.subset + ', it must be one of:'
                                                                 ' \'train\', \'val\'')
        def __getitem__(self,index):
            img =[index])
            target =[index]) if not self.subset == 'test' else None
            if self.transform is not None:
                img = self.transform(img)
                target = self.transform(target)
            return img, target
        def __len__(self):
            return len(self.data_path)

def im_show(img_list):
    to_PIL = transforms.ToPILImage()
    if len(img_list) > 9:
        raise Exception("len(img_list) must be smaller than 10")

    for idx, img in enumerate(img_list):
        img = np.array(to_PIL(img))
        plt.subplot(100 + 10 * len(img_list) + (idx + 1))
        fig = plt.imshow(img)


But when I run this,

train_dataset = DSB(root='/media/ryan/da5df9e4-cdc6-4d55-91e8-b2383e89165f/dsbdata/' ,
                        transforms.Scale((256, 256)),

train_loader =,

img_list = []
for i in range(4):
    img, label = train_dataset[i]

I get:

    NotImplementedError                       Traceback (most recent call last)
<ipython-input-286-d95b0e86d99d> in <module>()
     16 img_list = []
     17 for i in range(4):
---> 18     img, label = train_dataset[i]
     19     img_list.append(img)
     20     img_list.append(label)

/usr/local/lib/python3.5/dist-packages/torch/utils/data/ in __getitem__(self, index)
     12     def __getitem__(self, index):
---> 13         raise NotImplementedError
     15     def __len__(self):


Any suggestions on what I should change, or any other feedback, would be highly appreciated.

Thanks in advance

It seems the __getitem__(self, index) function is defined in __init__().
Maybe it’s just a formatting error while pasting the code in the forum, but could you please check it?
__init__, __getitem__ and __len__ should be defined in the class (on the same level).
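
To illustrate why the indentation matters, here is a minimal sketch. `BaseDataset` is a hypothetical stand-in for `torch.utils.data.Dataset` (it only mimics the default `__getitem__` that raises `NotImplementedError`, as seen in the traceback above), and `Nested` and `Fixed` are made-up names for this example:

```python
class BaseDataset:
    """Hypothetical stand-in for torch.utils.data.Dataset (assumption)."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class Nested(BaseDataset):
    def __init__(self, items):
        self.items = items

        # Bug: this def is indented inside __init__, so it is only a local
        # function there and never becomes a method of the class.
        def __getitem__(self, index):
            return self.items[index]


class Fixed(BaseDataset):
    def __init__(self, items):
        self.items = items

    # Same indentation level as __init__: this overrides the base method.
    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


try:
    Nested(["a", "b"])[0]
    nested_result = "ok"
except NotImplementedError:
    # Indexing falls through to the base class implementation.
    nested_result = "NotImplementedError"

fixed_result = Fixed(["a", "b"])[0]
print(nested_result, fixed_result)
```

With the nested definition, indexing falls back to the base class and raises; once the method sits at class level, indexing returns the stored item.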


Will check, one minute.

Thanks, that solved it.
But now I get:

    IndexError                                Traceback (most recent call last)
<ipython-input-288-6c9fa04cc535> in <module>()
     15 img_list = []
     16 for i in range(4):
---> 17     img, label = train_dataset[i]
     18     img_list.append(img)
     19     img_list.append(label)

<ipython-input-287-6a4c2792d56a> in __getitem__(self, index)
     26     def __getitem__(self,index):
---> 27         img =[index])
     28         target =[index]) if not self.subset == 'test' else None

IndexError: list index out of range

Any tips?

Also, do you have any feedback on my dataloader? Am I on the right track, or should I change something?

You are defining self.data_path = [] at first. But then you are assigning your image paths to self.datapath. Note the missing underscore.
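
A small sketch of what the typo does, using a hypothetical `Demo` class and made-up file names rather than your real dataset: the paths end up in a new attribute, so the list that `__getitem__` and `__len__` read stays empty.

```python
class Demo:
    def __init__(self, paths):
        self.data_path = []      # the attribute __getitem__ and __len__ read
        self.datapath = paths    # typo: creates a second, unrelated attribute

    def __getitem__(self, index):
        return self.data_path[index]

    def __len__(self):
        return len(self.data_path)


demo = Demo(["img0.png", "img1.png"])  # hypothetical file names
length = len(demo)                     # 0, because data_path was never filled

try:
    demo[0]
    error = None
except IndexError as exc:
    error = str(exc)                   # "list index out of range"

print(length, error)
```

Assigning to `self.data_path` in both branches makes the length and the indexing refer to the same list again.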

Will check this out.

Thank you for your time

Besides the minor bugs, it looks fine! :slight_smile:
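
For reference, here is a rough sketch of the path handling with both fixes applied. It only pairs file paths and leaves image loading and transforms out, so it runs without torch. `PairedPaths` and `list_files` are hypothetical names, the folder names are the ones from your code, and sorting both listings is my own suggestion: `os.listdir` returns entries in arbitrary order, so this assumes each mask has the same file name as its image.

```python
import os


def list_files(folder):
    """Return the full paths of the regular files in folder, sorted by name."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )


class PairedPaths:
    """Hypothetical helper that pairs each image path with its mask path."""

    # Folder names taken from the code in the question.
    SUBSETS = {"train": ("train", "train_label"), "val": ("val", "val_label")}

    def __init__(self, root, subset="train"):
        if subset not in self.SUBSETS:
            raise RuntimeError(
                "Invalid subset " + repr(subset) + ", it must be one of: 'train', 'val'"
            )
        image_dir, mask_dir = self.SUBSETS[subset]
        root = os.path.expanduser(root)
        # Same attribute names everywhere, so __len__ and __getitem__ agree.
        self.data_path = list_files(os.path.join(root, image_dir))
        self.label_path = list_files(os.path.join(root, mask_dir))
        if len(self.data_path) != len(self.label_path):
            raise RuntimeError("Number of images and masks differ")

    # Defined at class level, next to __init__.
    def __getitem__(self, index):
        return self.data_path[index], self.label_path[index]

    def __len__(self):
        return len(self.data_path)
```

In a real `Dataset` subclass, `__getitem__` would load both files and apply the transform before returning them.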

Thank you for your feedback, it was really helpful.