Trying to build a dataloader

I have four folders, organized like this:
Train, Train_masks

Val, Val_masks

I am trying to build a dataloader to read this data so that I can feed it to a network. By looking at a few examples, I have come up with this:

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class DSB(Dataset):
    def __init__(self, root, subset='train', transform=None):

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.subset = subset
        self.data_path, self.label_path = [], []

        def load_images(path):
            images_dir = [os.path.join(path, file) for file in os.listdir(path) if os.path.isfile(os.path.join(path, file))]
            images_dir.sort()

            return images_dir

        if self.subset == 'train':
            self.datapath = load_images(self.root + 'train')
            self.label_path = load_images(self.root + 'train_label')

        elif self.subset == 'val':
            self.datapath = load_images(self.root + 'val')
            self.label_path = load_images(self.root + 'val_label')
        else:
            raise RuntimeError('Invalid Dataset ' + self.subset + ', it must be one of:'
                               ' \'train\', \'val\'')

        def __getitem__(self, index):

            img = Image.open(self.data_path[index])
            target = Image.open(self.label_path[index]) if not self.subset == 'test' else None

            if self.transform is not None:
                img = self.transform(img)
                target = self.transform(target)

            return img, target

        def __len__(self):
            return len(self.data_path)


def im_show(img_list):

    to_PIL = transforms.ToPILImage()
    if len(img_list) > 9:
        raise Exception("len(img_list) must be smaller than 10")

    for idx, img in enumerate(img_list):
        img = np.array(to_PIL(img))
        plt.subplot(100 + 10 * len(img_list) + (idx + 1))
        fig = plt.imshow(img)
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)

    plt.show()
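As an aside on im_show: the integer argument to plt.subplot packs rows, columns and the 1-based panel index into three decimal digits (100 + 10 * n + (idx + 1) means 1 row, n columns, panel idx + 1), which is why the function has to reject more than nine images. A minimal sketch using the explicit three-argument form instead; unlike the original, it skips the ToPILImage step and assumes the caller already passes NumPy-compatible arrays:

```python
import matplotlib.pyplot as plt
import numpy as np


def im_show(img_list):
    """Show images side by side in a single row, with no 9-image limit.

    Assumes every item is already an (H, W) or (H, W, C) array, e.g. a
    tensor converted with np.array(transforms.ToPILImage()(t)).
    """
    fig = plt.figure(figsize=(2 * max(len(img_list), 1), 2))
    for idx, img in enumerate(img_list):
        # Three-argument form: (nrows, ncols, 1-based index), no digit packing.
        ax = fig.add_subplot(1, len(img_list), idx + 1)
        ax.imshow(np.asarray(img))
        ax.axis("off")  # hide ticks, tick labels and the axes frame
    plt.show()
    return fig
```

Returning the figure is a small convenience so the caller can save it with fig.savefig(...) after displaying it.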


But when I run this:

train_dataset = DSB(root='/media/ryan/da5df9e4-cdc6-4d55-91e8-b2383e89165f/dsbdata/' ,
                    subset="train",
                    transform=transforms.Compose([
                        transforms.Scale((256, 256)),
                        transforms.ToTensor()])
                    )

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=8,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=1)


img_list = []
for i in range(4):
    img, label = train_dataset[i]
    img_list.append(img)
    img_list.append(label)
im_show(img_list)

I get:

    NotImplementedError                       Traceback (most recent call last)
<ipython-input-286-d95b0e86d99d> in <module>()
     16 img_list = []
     17 for i in range(4):
---> 18     img, label = train_dataset[i]
     19     img_list.append(img)
     20     img_list.append(label)

/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataset.py in __getitem__(self, index)
     11 
     12     def __getitem__(self, index):
---> 13         raise NotImplementedError
     14 
     15     def __len__(self):

NotImplementedError: 

Any suggestions on what I should change, or any other feedback, would be highly appreciated.

Thanks in advance!

It seems the __getitem__(self, index) function is defined in __init__().
Maybe it’s just a formatting error while pasting the code in the forum, but could you please check it?
__init__, __getitem__ and __len__ should all be defined directly in the class body, at the same indentation level.
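To see why the traceback ends in torch's own Dataset.__getitem__, here is a minimal, torch-free reproduction. BaseDataset is a hypothetical stand-in that, like the Dataset base class shown in the traceback, raises NotImplementedError from __getitem__. A def indented inside __init__ only creates a local function that is discarded when __init__ returns, so indexing falls through to the base class:

```python
class BaseDataset:
    """Hypothetical stand-in for the base class seen in the traceback."""

    def __getitem__(self, index):
        raise NotImplementedError


class Nested(BaseDataset):
    def __init__(self):
        self.items = ["a", "b"]

        # Indented one level too deep: a local function inside __init__,
        # not a method; it is thrown away when __init__ returns.
        def __getitem__(self, index):
            return self.items[index]


class Flat(BaseDataset):
    def __init__(self):
        self.items = ["a", "b"]

    # Same indentation level as __init__: a real method of the class.
    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


try:
    Nested()[0]
except NotImplementedError:
    print("Nested: falls back to BaseDataset.__getitem__")

print("Flat:", Flat()[0])
```

A quick way to check a class for this mistake is "__getitem__" in vars(MyClass), which is only true when the method is defined in the class body itself.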


Will check, one minute.

Thanks, that solved it.
But now I get:

    IndexError                                Traceback (most recent call last)
<ipython-input-288-6c9fa04cc535> in <module>()
     15 img_list = []
     16 for i in range(4):
---> 17     img, label = train_dataset[i]
     18     img_list.append(img)
     19     img_list.append(label)

<ipython-input-287-6a4c2792d56a> in __getitem__(self, index)
     25 
     26     def __getitem__(self,index):
---> 27         img = Image.open(self.data_path[index])
     28         target = Image.open(self.label_path[index]) if not self.subset == 'test' else None
     29 

IndexError: list index out of range

Any tips?

Also, do you have any feedback on my dataloader? Am I on the right track, or should I change something?

You are defining self.data_path = [] at first, but then you are assigning your image paths to self.datapath (note the missing underscore). So self.data_path stays an empty list, __len__ returns 0, and any index raises IndexError.
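The typo is easy to miss because assigning to a misspelled attribute silently creates a new one instead of failing. A minimal sketch of a more defensive loader, under stated assumptions: the helper name pair_paths is hypothetical, and it assumes each mask has the same file name (stem) as its image, which may not match the real dataset:

```python
import os


def list_files(path):
    """Return sorted full paths of the regular files in `path`."""
    return sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )


def pair_paths(image_dir, mask_dir):
    """Hypothetical helper: pair each image with the mask sharing its stem.

    Raises immediately instead of leaving an empty list behind, which is
    what later surfaced as the IndexError in __getitem__.
    """
    images = list_files(image_dir)
    if not images:
        raise RuntimeError(f"no images found in {image_dir!r}")
    # Map file-name stem (name without extension) -> mask path.
    masks = {os.path.splitext(os.path.basename(p))[0]: p for p in list_files(mask_dir)}

    data_path, label_path = [], []
    for img in images:
        stem = os.path.splitext(os.path.basename(img))[0]
        if stem not in masks:
            raise RuntimeError(f"no mask for {img!r} in {mask_dir!r}")
        data_path.append(img)
        label_path.append(masks[stem])
    return data_path, label_path
```

Inside __init__ this would become a single statement, self.data_path, self.label_path = pair_paths(image_dir, mask_dir), so the attribute names are written once, next to each other, and an empty or mismatched folder fails at construction time rather than at the first index.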

Will check this out.

Thank you for your time.

Besides the minor bugs, it looks fine! 🙂
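Putting both fixes together, a corrected dataset could look like the sketch below. It keeps the folder names used in the question's code (train, train_label, val, val_label; the question text mentions Train_masks and Val_masks, so adjust to the real names) and builds paths with os.path.join, so a missing trailing slash on root no longer matters. The target_transform parameter is an addition not present in the original code; it follows the convention of torchvision's built-in datasets and lets masks get their own pipeline:

```python
import os

from PIL import Image
from torch.utils.data import Dataset


class DSB(Dataset):
    """Paired image/mask dataset with both fixes from this thread applied.

    Assumed layout (folder names taken from the question's code):
        root/train, root/train_label, root/val, root/val_label
    """

    FOLDERS = {"train": ("train", "train_label"), "val": ("val", "val_label")}

    def __init__(self, root, subset="train", transform=None, target_transform=None):
        if subset not in self.FOLDERS:
            raise RuntimeError(
                f"Invalid subset {subset!r}, it must be one of: 'train', 'val'"
            )
        self.root = os.path.expanduser(root)
        self.subset = subset
        self.transform = transform
        # Addition, not in the original: a separate pipeline for the masks.
        self.target_transform = target_transform

        image_folder, mask_folder = self.FOLDERS[subset]
        # Fix 2: assign to the same attribute names that __getitem__ reads.
        # Images and masks are paired by sorted position, as in the original.
        self.data_path = self._load_images(os.path.join(self.root, image_folder))
        self.label_path = self._load_images(os.path.join(self.root, mask_folder))
        if len(self.data_path) != len(self.label_path):
            raise RuntimeError(
                f"{len(self.data_path)} images but {len(self.label_path)} masks"
            )

    @staticmethod
    def _load_images(path):
        """Sorted full paths of the regular files in `path`."""
        return sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if os.path.isfile(os.path.join(path, name))
        )

    # Fix 1: __getitem__ and __len__ are methods at the same level as __init__.
    def __getitem__(self, index):
        img = Image.open(self.data_path[index])
        target = Image.open(self.label_path[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.data_path)
```

With torchvision, the image pipeline could stay Resize((256, 256)) followed by ToTensor() (Resize is the current name of the deprecated Scale), while in recent torchvision versions the mask pipeline could use Resize((256, 256), interpolation=InterpolationMode.NEAREST) followed by PILToTensor(), so label values are neither blended by interpolation nor rescaled to [0, 1].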

Thank you for your feedback, it was really helpful.