I am trying to implement a Siamese network using PyTorch and PyTorch Lightning, and I have a question about how to split such a dataset into training and validation sets.
For a Siamese network, we create triplets of input data points: an anchor image, a positive image (belonging to the same class as the anchor) and a negative image (belonging to a different class). I am using the PyTorch `Dataset` functionality for this, and the dataset creation currently looks like this:
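```python
import glob
import os
import random

from torch.utils.data import Dataset


class SiameseTriplet(Dataset):
    def __init__(self, image_folder_dataset, transform=None):
        # Store things to class variables
        self.dataset = image_folder_dataset
        self.transform = transform
        ...

    def __getitem__(self, index):
        # ImageFolder.imgs is a list of (path, class_index) tuples; pick one as the anchor
        img0_tuple = random.choice(self.dataset.imgs)
        # randomly sample to get a negative sample
        ...
        # The parent folder of the anchor image is its class name
        anchor_class_name = os.path.basename(os.path.dirname(img0_tuple[0]))
        all_files_in_class = glob.glob(os.path.join(self.dataset.root, anchor_class_name, '*'))
        # randomly sample now to get the positive image
        positive_image = random.choice(all_files_in_class)
        # Do some image transformations
        ...
        return anchor, positive, negative
```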
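As written, `__getitem__` ignores `index` and samples from the full `self.dataset.imgs` (and globs positives from the whole class folder under `root`), so wrapping a `SiameseTriplet` in a `Subset` would not limit which images it sees, and validation images could still show up in training triplets. A minimal sketch of an index-driven variant, assuming the triplets are built from a list of `(path, class_index)` pairs in the layout of `ImageFolder.samples`. `TripletsFromSamples` is a hypothetical name; it returns paths rather than loaded images, and it leaves out the `torch.utils.data.Dataset` base class only so the sketch runs without torch (subclassing it the way `SiameseTriplet` does would work the same way).

```python
import random
from collections import defaultdict


class TripletsFromSamples:
    """Hypothetical map-style dataset returning (anchor, positive, negative)
    image paths drawn only from `samples`, a list of (path, class_index) pairs
    such as ImageFolder.samples or one half of a train/val split."""

    def __init__(self, samples, seed=None):
        self.samples = list(samples)
        self.rng = random.Random(seed)
        # Group positions in `samples` by class label
        self.by_class = defaultdict(list)
        for position, (_, label) in enumerate(self.samples):
            self.by_class[label].append(position)
        if len(self.by_class) < 2:
            raise ValueError("need at least two classes to draw negatives")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # The anchor is chosen by `index`, so the sampler decides which images are visited
        anchor_path, anchor_label = self.samples[index]
        # Positive: another image of the same class (the anchor itself if it is alone)
        same_class = [p for p in self.by_class[anchor_label] if p != index] or [index]
        positive_path = self.samples[self.rng.choice(same_class)][0]
        # Negative: pick a different class, then an image from that class
        negative_label = self.rng.choice([c for c in self.by_class if c != anchor_label])
        negative_path = self.samples[self.rng.choice(self.by_class[negative_label])][0]
        # Image loading and transforms would go here, as in SiameseTriplet
        return anchor_path, positive_path, negative_path
```

Because positives and negatives come only from the list the object was built with, a training instance and a validation instance built from disjoint lists never mix images.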
What I want to do now is create a validation dataset, and I am not sure what the best strategy is. Say I want a 95/5 split between training and validation. I thought the best way would be to split the dataset so that each class gets the 95/5 split. Say I have three classes and my data is organized as follows:
| ------- horse
In that case, I would apply the 95/5 split to each of these classes. Is that a reasonable thing to do for Siamese networks?
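Splitting each class separately like this is a stratified split. A minimal sketch in plain Python, assuming the images are available as `(path, class_index)` pairs in the layout of `ImageFolder.samples`; `split_samples_per_class` is a hypothetical helper, and keeping at least one training image per class is an added assumption, not something from the setup above.

```python
import random
from collections import defaultdict


def split_samples_per_class(samples, val_fraction=0.05, seed=0):
    """Hypothetical helper: split (path, class_index) pairs so that every class
    keeps roughly the same train/val ratio. Returns (train, val) lists."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for sample in samples:
        by_class[sample[1]].append(sample)

    train, val = [], []
    for class_samples in by_class.values():
        # Shuffle within the class so the validation images are a random subset
        rng.shuffle(class_samples)
        # Per-class validation count, keeping at least one training image per class
        n_val = min(round(len(class_samples) * val_fraction), len(class_samples) - 1)
        val.extend(class_samples[:n_val])
        train.extend(class_samples[n_val:])
    return train, val
```

With 100 images in each of three classes and `val_fraction=0.05`, this gives 95 training and 5 validation images per class. Each list could then back its own triplet dataset, so that validation triplets are formed only from validation images.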
Now in my training code, I do something like:
```python
import torchvision.datasets as dset

train_folder_dataset = dset.ImageFolder(root=data_path)
train_dataset = SiameseTriplet(image_folder_dataset=train_folder_dataset, transform=transform)
```
It seems to me that the ideal place to split the dataset would be right after the `train_folder_dataset = dset.ImageFolder(root=data_path)` line. However, I cannot figure out how to do that with the strategy described above, where each class gets the specified split.
I know there is a `torch.utils.data.Subset` class for splitting an `ImageFolder` dataset, but how do I tell it to apply the split per class?
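`Subset` only takes a dataset and a list of indices, so any per-class behaviour has to come from how those indices are chosen. A minimal sketch, assuming the labels come from `ImageFolder.targets` (one class index per image, in the same order as the dataset); `per_class_subsets` is a hypothetical helper name. Note that a `Subset` does not expose `.imgs` or `.root`, so `SiameseTriplet` as written would still sample from the full folder rather than from the subset.

```python
import torch
from torch.utils.data import Subset


def per_class_subsets(dataset, targets, val_fraction=0.05, seed=0):
    """Hypothetical helper: return (train_subset, val_subset) where roughly
    `val_fraction` of every class goes to validation.
    `targets` holds one class index per sample, e.g. ImageFolder.targets."""
    targets = torch.as_tensor(targets)
    generator = torch.Generator().manual_seed(seed)
    train_idx, val_idx = [], []
    for label in torch.unique(targets):
        # Indices of every sample in this class, shuffled reproducibly
        class_idx = torch.nonzero(targets == label).flatten()
        class_idx = class_idx[torch.randperm(len(class_idx), generator=generator)]
        # Round the per-class validation count
        n_val = round(len(class_idx) * val_fraction)
        val_idx.extend(class_idx[:n_val].tolist())
        train_idx.extend(class_idx[n_val:].tolist())
    return Subset(dataset, train_idx), Subset(dataset, val_idx)
```

With an `ImageFolder`, the call would look like `per_class_subsets(train_folder_dataset, train_folder_dataset.targets, val_fraction=0.05)`; the resulting `.indices` lists can also be used to pick the matching entries of `train_folder_dataset.samples` for a triplet dataset.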