I am trying to implement a Siamese network using PyTorch and PyTorch Lightning, and I have a question about how to split the dataset into training and validation sets.
For a Siamese network trained on triplets, each input consists of an anchor image, a positive image (belonging to the same class as the anchor) and a negative image (belonging to a different class). I am building on the PyTorch Dataset class, and the dataset creation currently looks like this:
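    import glob
    import os
    import random

    from torch.utils.data import Dataset


    class SiameseTriplet(Dataset):
        def __init__(self, torchvision_dataset, transform=None):
            # Store things to class variables
            ....

        def __getitem__(self, index):
            # Pick a random (path, class_index) entry as the anchor
            img0_tuple = random.choice(self.dataset.imgs)
            # randomly sample to get a negative sample
            ...
            # The class name is the parent directory of the anchor image
            anchor_class_name = img0_tuple[0].split('/')[-2]
            # os.path.join avoids a missing separator when root has no trailing slash
            all_files_in_class = glob.glob(os.path.join(self.dataset.root, anchor_class_name, '*'))
            # randomly sample now to get the positive image
            positive_image = random.choice(all_files_in_class)
            # Do some image transformations
            ....
            return anchor, positive, negative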
Now I want to create a validation dataset, and I am not sure what the best strategy is. Say I want a 95-5 split between training and validation. I thought the best way would be to split the dataset so that each class gets its own 95-5 split. So say I have three classes and my data is organized as follows:
    root
    |-- cat
    |-- dog
    |-- horse
In that case, I would apply the 95-5 split within each of these classes. Is that a reasonable thing to do for Siamese networks?
Now in my training code, I do something like:
    import torchvision.datasets as dset

    train_folder_dataset = dset.ImageFolder(root=data_path)
    train_dataset = SiameseTriplet(torchvision_dataset=train_folder_dataset,
                                   transform=transform)
It seems to me that the ideal place to split the dataset would be right after the train_folder_dataset = dset.ImageFolder(root=data_path) line. However, I cannot figure out how to do so with the strategy described above, where each class gets the specified split.
I know there is a torch.utils.data.Subset class for taking a subset of an ImageFolder, but how do I tell it to apply the split per class as well?
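To make the per-class split concrete, here is a minimal sketch of the index bookkeeping I have in mind. It is plain Python and only assumes a list of integer class labels, one per sample, such as the targets attribute of an ImageFolder. The function name stratified_split_indices, the val_fraction and seed parameters, and the rule of keeping at least one validation sample per class are my own illustrative choices, not part of any library.

```python
import random
from collections import defaultdict


def stratified_split_indices(targets, val_fraction=0.05, seed=0):
    """Split sample indices into train/val lists, class by class.

    `targets` is a sequence of class labels, one per sample
    (hypothetically, `train_folder_dataset.targets`).
    """
    rng = random.Random(seed)

    # Group sample indices by their class label.
    indices_by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        indices_by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label in sorted(indices_by_class):
        indices = indices_by_class[label]
        rng.shuffle(indices)
        # Assumption: aim for at least one validation sample per class,
        # but always leave at least one sample for training.
        n_val = max(1, round(len(indices) * val_fraction))
        n_val = min(n_val, len(indices) - 1)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


# Toy labels standing in for real targets: 40 cats, 40 dogs, 20 horses.
toy_targets = [0] * 40 + [1] * 40 + [2] * 20
train_idx, val_idx = stratified_split_indices(toy_targets, val_fraction=0.05)
```

The two index lists could then presumably be passed to torch.utils.data.Subset(train_folder_dataset, train_idx) and torch.utils.data.Subset(train_folder_dataset, val_idx). One caveat: a Subset exposes only dataset and indices, not imgs or root, so my SiameseTriplet class as written would have to be adapted to draw its anchor, positive and negative samples from the selected indices only. Otherwise validation images could leak into the training triplets.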