I’m implementing a Siamese network. Herefor I always need two images, which should be randomly sampled with p=0.5 as both from the same class and from different classes.
My idea is
class SiameseDataset(MyOwnDataset):
# Source: https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese
# -networks-medium.ipynb
def __init__(self, args, file, augumentation=None):
super().__init__(args, file, augumentation)
def __getitem__(self, index):
# We need approx 50 % of images of the same class
same_class = random.randint(0, 1)
img_0 = self.data[index]
label_0 = self.labels[index]
if same_class:
while True:
# keep looping till the same class image is found
index_1 = random.randint(0, self.__len__())
label_1 = self.labels[index_1]
if label_0 == label_1:
img_1 = self.data[index_1]
break
else:
while True:
index_1 = random.randint(0, self.__len__())
label_1 = self.labels[index_1]
if label_0 != label_1:
img_1 = self.data[index_1]
break
return (img_0, label_0), (img_1, label_1)
I face now the issue that MyOwnDataset is a ConcatDataset out of 6 Datasets belonging all to different classes. So when I need a sample from a different dataset, this does not work as self.data[index] belongs to samples only from the same class.
Hi, unfortunately, the solution is not efficient since it searches for positive/negative pairs randomly. This takes a long time when the dataset is large. To improve the speed for finding pairs, you can use a dictionary storing the correspondences.
class ContrastiveDataset(Dataset):
def __init__(self, dataset, positive_prob=0.5):
super().__init__()
self.dataset = dataset
self.positive_prob = positive_prob
self.h = {}
# construct the hash table for the correspondenses
for i, im in enumerate(self.dataset.imgs):
# im[0] is the image address, and im[1] is the label
lbl = im[1]
if lbl in self.h:
self.h[lbl].append(i)
else:
self.h[lbl] = [i]
def __getitem__(self, index):
same_class = random.uniform(0, 1)
same_class = same_class > self.positive_prob
img_0, label_0 = self.dataset[index]
class_samples = self.h[label_0]
if same_class:
while True:
rnd_idx = random.randint(0, len(class_samples) - 1)
index_1 = class_samples[rnd_idx]
if index_1 != index:
img_1, label_1 = self.dataset[index_1]
break
else:
while True:
index_1 = random.randint(0, self.__len__() - 1)
if index_1 != index:
img_1, label_1 = self.dataset[index_1]
break
return (img_0, label_0), (img_1, label_1)
def __len__(self):
return len(self.dataset)