Hi, I implemented the custom dataset in which both features and labels are images. I’m unable to read the masks it is showing assertion error. Can someone help with this please I’m stuck here please find my code below:
class DirDataset(Dataset):
def init(self, img_dir, mask_dir):
self.img_dir = img_dir
self.mask_dir = mask_dir
try:
self.ids = [s.split('.')[0] for s in os.listdir(self.img_dir)]
except FileNotFoundError:
self.ids = []
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
idx = self.ids[i]
img_files = glob.glob(os.path.join(self.img_dir, idx+'.*'))
mask_files = glob.glob(os.path.join(self.mask_dir, idx+'_mask.*'))
assert len(img_files) == 1, f'{idx}: {img_files}'
assert len(mask_files) == 1, f'{idx}: {mask_files}'
# use Pillow's Image to read .gif mask
# https://answers.opencv.org/question/185929/how-to-read-gif-in-python/
img = Image.open(img_files[0])
mask = Image.open(mask_files[0])
assert img.size == mask.size, f'{img.shape} # {mask.shape}'
# img = self.preprocess(img)
# mask = self.preprocess(mask)
return torch.from_numpy(img).float(), \
torch.from_numpy(mask).float()