I have the following Custom dataset class for image segmentation task.
class LoadDataset(Dataset):
def __init__(self, img_dir, mask_dir, apply_transforms = None):
self.img_dir = img_dir
self.mask_dir = mask_dir
self.transforms = apply_transforms
self.img_paths, self.mask_paths = self.__get_all_paths()
self.__pil_to_tensor = transforms.PILToTensor()
self.__float_tensor = transforms.ToDtype(torch.float32, scale = True)
self.__grayscale = transforms.Grayscale()
def __get_all_paths(self):
img_paths = [os.path.join(self.img_dir, img_name.name) for img_name in os.scandir(self.img_dir) if os.path.isfile(img_name)]
mask_paths = [os.path.join(self.mask_dir, mask_name.name) for mask_name in os.scandir(self.mask_dir) if os.path.isfile(mask_name)]
img_paths = sorted(img_paths)
mask_paths = sorted(mask_paths)
return img_paths, mask_paths
def __len__(self):
return len(self.img_paths)
def __getitem__(self, index):
img_path, mask_path = self.img_paths[index], self.mask_paths[index]
img_PIL = Image.open(img_path)
mask_PIL = Image.open(mask_path)
img_tensor = self.__pil_to_tensor(img_PIL)
img_tensor = self.__float_tensor(img_tensor)
mask_tensor = self.__pil_to_tensor(mask_PIL)
mask_tensor = self.__float_tensor(mask_tensor)
mask_tensor = self.__grayscale(mask_tensor)
if self.transforms:
img_tensor, mask_tensor = self.transforms(img_tensor, mask_tensor)
return img_tensor, mask_tensor
When I am applying the following transforms.RandomHorizontalFlip()
either the image or the mask is being flipped. But if the change the order of transformations in __getitem__
to the following then it works fine.
def __getitem__(self, index):
img_path, mask_path = self.img_paths[index], self.mask_paths[index]
img_PIL = Image.open(img_path)
mask_PIL = Image.open(mask_path)
if self.transforms:
img_PIL, mask_PIL = self.transforms(img_PIL, mask_PIL)
img_tensor = self.__pil_to_tensor(img_PIL)
mask_tensor = self.__pil_to_tensor(mask_PIL)
img_tensor = self.__float_tensor(img_tensor)
mask_tensor = self.__float_tensor(mask_tensor)
mask_tensor = self.__grayscale(mask_tensor)
return img_tensor, mask_tensor
Does the order transformation matter somehow? I am using torchvision.transforms.v2
for all the transformations.