I have a custom dataset class where I apply some transforms.
My definition of the transforms are:
train_transform = {'img': transforms.Compose([transforms.Resize(Polyp.IMAGE_SIZE),
transforms.RandomVerticalFlip(),
# transforms.RandomHorizontalFlip(),
# transforms.RandomCrop(size=(320, 320), pad_if_needed=True),
# transforms.RandomResizedCrop(size=320,
# interpolation=Image.NEAREST),
transforms.RandomRotation(180),
transforms.ColorJitter(brightness=0.2,
contrast=0,
saturation=0,
hue=.5),
transforms.ToTensor()]),
'mask': transforms.Compose([transforms.Resize(Polyp.IMAGE_SIZE),
# transforms.RandomVerticalFlip(),
# transforms.RandomHorizontalFlip(),
# transforms.RandomCrop(size=(320, 320), pad_if_needed=True),
# transforms.RandomResizedCrop(size=320,
# interpolation=Image.NEAREST),
transforms.RandomRotation(180),
transforms.ToTensor()])}
In the __getitem__
method I do the following:
class SegDataset(Dataset):
def __init__(self, paths_df, transform=None, train_flag=True):
self.training_flag = train_flag
self.transform = transform
self.images = paths_df["img_path"].to_list()
self.segmentations = paths_df["seg_path"].to_list()
def __getitem__(self, idx):
image = Image.open(self.images[idx])
image = image.convert('RGB')
mask = Image.open(self.segmentations[idx])
mask = mask.convert('L')
mask = self.prepare_mask(mask)
seed = np.random.randint(sys.maxsize)
if self.transform is not None:
torch.manual_seed(seed) # apply this seed to img transforms
image = self.transform['img'](image)
torch.manual_seed(seed)
mask = self.transform['mask'](mask)
mask = mask[0, :, :].unsqueeze(0)
return image, mask
def __len__(self):
return len(self.images)
@staticmethod
def prepare_mask(mask):
mask = np.array(mask)
mask = np.where(mask > 0, 255, 0)
mask = mask[:, :, None].repeat(3, axis=2)
mask = Image.fromarray(mask.astype('uint8'))
return mask
Because my mask is a binary one, I expand it to 3 channels and then apply the transforms while taking only the first channel. But although I set the seed to the same value before applying each of the transforms the augmentation is different.
Note: I also tried with random.seed(seed) but still no success.
I might be missing something and would appreciate some help.