The data your previous code already sampled can be reused as-is; just wrap that dataset in the pair-producing `Siamese` dataset below:
class Siamese(Dataset):
    """Wrap a labelled dataset so that every item is a pair of samples.

    ``dataset[i]`` must return an ``(image, label)`` pair. Item ``i`` of this
    wrapper is ``dataset[i]`` together with a randomly chosen partner sample.
    With probability 1/2 the partner has the same label, otherwise a
    different one.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        """Return ``((img_0, label_0), (img_1, label_1))`` for ``dataset[index]``.

        The partner is drawn uniformly among the qualifying samples other than
        ``index`` itself. If ``label_0`` occurs only once and a same-class
        partner was requested, the anchor is paired with itself.

        Raises:
            ValueError: if a different-label partner was requested but every
                other sample shares ``label_0`` (e.g. a single-class dataset).
                The previous implementation looped forever in that case.
        """
        # We need approx 50 % of the pairs to be of the same class.
        same_class = random.randint(0, 1)
        img_0, label_0 = self.dataset[index]

        partner = self._find_partner(index, label_0, same_class)
        if partner is not None:
            return (img_0, label_0), partner

        if same_class:
            # label_0 is unique in the dataset: the anchor is the only
            # same-class sample, so pair it with itself.
            return (img_0, label_0), (img_0, label_0)
        raise ValueError(
            f"Siamese: no sample with a label different from {label_0!r}; "
            "the dataset needs at least two distinct labels"
        )

    def __len__(self):
        return len(self.dataset)

    def _find_partner(self, index, label_0, same_class):
        """Return a random ``(img, label)`` matching / differing from ``label_0``.

        Candidates are visited in uniformly random order without replacement.
        The result is therefore uniform over all qualifying samples, and each
        sample is loaded at most once. Returns ``None`` when none qualifies.
        """
        for candidate in self._random_order(len(self)):
            if candidate == index:
                continue  # never pair a sample with itself
            img_1, label_1 = self.dataset[candidate]
            if same_class:
                found = label_0 == label_1
            else:
                found = label_0 != label_1
            if found:
                return img_1, label_1
        return None

    @staticmethod
    def _random_order(n):
        """Lazily yield ``0 .. n-1`` in uniformly random order.

        This is a sparse Fisher-Yates shuffle of the virtual array
        ``[0, 1, ..., n-1]``. Only displaced slots are stored, so drawing the
        first ``k`` indices costs O(k) rather than shuffling all ``n``.
        """
        displaced = {}
        for i in range(n):
            j = random.randint(i, n - 1)
            # Swap slots i and j, then emit slot i (the old value of slot j).
            # Read slot j before popping slot i so that j == i works.
            value_j = displaced.get(j, j)
            value_i = displaced.pop(i, i)
            if j != i:
                displaced[j] = value_i
            yield value_j