@ptrblck I am wondering how I can add a condition to `CustomDataset` so that data augmentation is applied only to a few specific training images (image_207, image_387, image_502, image_508, image_509, image_520, image_597).
This is the `CustomDataset` snippet; basically, I added `self.transformm` to the previous code I posted above. I think I need to add an `if` condition in `__getitem__` so that `self.transformm` is applied only to those images. Could you please point me in the right direction?
class CustomDataset(Dataset):
    """Segmentation dataset pairing grayscale images with class-index masks.

    Each item is ``(image, mask, image_path, target_path)``:

    - ``image`` is the input converted to grayscale ('L') and passed through
      ``transforms.ToTensor()``.
    - ``mask`` is a ``long`` tensor whose raw pixel values have been remapped
      through ``self.mapping``.

    Images whose file name appears in ``transform_images`` are additionally
    passed through ``self.transformm`` (random rotation + translation). The
    same random draw is applied to the mask so that labels stay aligned with
    pixels.

    Args:
        image_paths: Sequence of input image paths.
        target_paths: Sequence of mask paths, index-aligned with
            ``image_paths``.
        transform_images: Names of the images to augment. Each name is
            matched exactly against the file name, with or without its
            extension (e.g. ``"image_207"`` or ``"image_207.png"``). ``None``
            or an empty sequence disables augmentation.
    """

    def __init__(self, image_paths, target_paths, transform_images=None):
        self.image_paths = image_paths
        self.target_paths = target_paths
        # Use the Random* transform classes, not the functional
        # tf.rotate / tf.affine: the functional ops take the image as their
        # first argument and do not build a reusable transform.
        self.transformm = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),
        ])
        self.transform = transforms.ToTensor()
        self.transform_images = transform_images
        # Raw mask pixel value -> class index.
        self.mapping = {
            0: 0,
            255: 1,
        }

    def mask_to_class(self, mask):
        """Remap raw mask values to class indices via ``self.mapping``.

        All selectors are computed from the input values before any
        assignment. This way chained mappings (e.g. ``{1: 2, 2: 3}``) are not
        applied twice. ``mask`` is modified in place and also returned.
        """
        selectors = [(mask == key, value) for key, value in self.mapping.items()]
        for selector, value in selectors:
            mask[selector] = value
        return mask

    def _should_augment(self, path):
        """Return True if ``path``'s file name is listed in ``transform_images``."""
        import os

        selected = set(self.transform_images or ())
        name = os.path.basename(str(path))
        stem = os.path.splitext(name)[0]
        # Exact name/stem match, so "image_207" does not also match "image_2070".
        return name in selected or stem in selected

    def _augment_pair(self, t_image, mask):
        """Apply a single random draw of ``self.transformm`` to image and mask.

        The mask is stacked onto the image along the channel dimension, so
        RandomRotation and RandomAffine sample one set of parameters for both.
        Their default interpolation is nearest-neighbour, which keeps mask
        values exact. Their default fill of 0 pads the mask with raw value 0.

        NOTE(review): requires a torchvision version whose Random* transforms
        accept tensors. Also assumes ``t_image`` is (C, H, W) and ``mask`` is
        (H, W) with matching H, W. TODO: confirm for multi-channel mask files.
        """
        stacked = torch.cat([t_image, mask.unsqueeze(0).to(t_image.dtype)], dim=0)
        stacked = self.transformm(stacked)
        return stacked[:-1], stacked[-1].round().to(torch.uint8)

    def __getitem__(self, index):
        """Load, optionally augment, and return the sample at ``index``."""
        image_path = self.image_paths[index]
        target_path = self.target_paths[index]
        # Context managers close the file handles opened by Image.open.
        with Image.open(image_path) as image:
            t_image = image.convert('L')
        with Image.open(target_path) as mask_file:
            mask = torch.from_numpy(numpy.array(mask_file, dtype=numpy.uint8))
        t_image = self.transform(t_image)
        if self._should_augment(image_path):
            t_image, mask = self._augment_pair(t_image, mask)
        mask = self.mask_to_class(mask)
        mask = mask.long()
        return t_image, mask, image_path, target_path

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)