Unable to read the masks in `__getitem__()` of a custom dataset. Please help!

Hi, I implemented a custom dataset in which both the features and the labels are images. I'm unable to read the masks: it raises an assertion error. Can someone please help with this? I'm stuck here. Please find my code below:

class DirDataset(Dataset):
    """Dataset of image/mask pairs stored in two separate directories.

    An image ``<id>.<ext>`` found in ``img_dir`` is paired with the mask
    ``<id>_mask.<ext>`` in ``mask_dir``; the two extensions may differ.
    """

    def __init__(self, img_dir, mask_dir):
        """Remember both directories and collect the sample ids.

        Args:
            img_dir: Directory containing the input images.
            mask_dir: Directory containing the ``<id>_mask.*`` files.
        """
        super().__init__()
        self.img_dir = img_dir
        self.mask_dir = mask_dir

        try:
            # A sample id is the file name up to its first dot.
            self.ids = [s.split('.')[0] for s in os.listdir(self.img_dir)]
        except FileNotFoundError:
            # A missing image directory yields an empty dataset, not a crash.
            self.ids = []

    def __len__(self):
        """Return the number of sample ids found in ``img_dir``."""
        return len(self.ids)

    def __getitem__(self, i):
        """Load the ``i``-th image and its mask as float tensors.

        Args:
            i: Index into ``self.ids``.

        Returns:
            A ``(image, mask)`` tuple of float tensors built from the
            decoded pixel arrays.

        Raises:
            AssertionError: If the id does not match exactly one image file
                and exactly one mask file, or if image and mask sizes differ.
        """
        # Local import so the class does not depend on a module-level numpy
        # import being present in the file.
        import numpy as np

        idx = self.ids[i]
        img_files = glob.glob(os.path.join(self.img_dir, idx + '.*'))
        mask_files = glob.glob(os.path.join(self.mask_dir, idx + '_mask.*'))

        # Raised explicitly (rather than via `assert`) so the checks still run
        # under `python -O`; the exception type and messages are unchanged.
        if len(img_files) != 1:
            raise AssertionError(f'{idx}: {img_files}')
        if len(mask_files) != 1:
            raise AssertionError(f'{idx}: {mask_files}')

        # use Pillow's Image to read .gif mask
        # https://answers.opencv.org/question/185929/how-to-read-gif-in-python/
        with Image.open(img_files[0]) as img, Image.open(mask_files[0]) as mask:
            # PIL images expose `.size` (width, height); they have no `.shape`.
            if img.size != mask.size:
                raise AssertionError(f'{img.size} # {mask.size}')

            # torch.from_numpy() only accepts ndarrays, so decode the PIL
            # images first. np.array() copies, which gives writable arrays.
            # NOTE(review): palette-mode masks (e.g. .gif) presumably decode
            # to palette indices here - confirm that is what training expects.
            img_arr = np.array(img)
            mask_arr = np.array(mask)

        # img = self.preprocess(img)
        # mask = self.preprocess(mask)

        return torch.from_numpy(img_arr).float(), \
            torch.from_numpy(mask_arr).float()

What kind of error are you seeing?

PS: you can post code snippets by wrapping them into three backticks ```, which makes debugging easier :wink: