Dataloader for a Siamese Model with ConcatDataset

Hi everyone,

I’m implementing a Siamese network. For this, I always need two images per sample: with probability p=0.5 both should come from the same class, and otherwise they should come from different classes.

My idea is

class SiameseDataset(MyOwnDataset):
    """Return pairs of samples for Siamese training.

    Adapted from:
    https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb

    Each item is ``((img_0, label_0), (img_1, label_1))``. With ~50 %
    probability the partner shares the anchor's label; otherwise its label
    differs. Relies on ``self.data`` and ``self.labels`` being index-aligned
    sequences provided by ``MyOwnDataset`` (not visible here — confirm).
    """

    def __init__(self, args, file, augumentation=None):
        super().__init__(args, file, augumentation)

    def __getitem__(self, index):
        # We need approx 50 % of images of the same class.
        same_class = random.randint(0, 1)
        img_0 = self.data[index]
        label_0 = self.labels[index]
        # NOTE: loops forever if no sample satisfies the requested relation
        # (e.g. a negative pair is requested but every label is identical).
        while True:
            # randrange excludes len(self); the previous randint(0, len(self))
            # could return len(self) and raise IndexError.
            index_1 = random.randrange(len(self))
            label_1 = self.labels[index_1]
            if same_class:
                found = label_0 == label_1
            else:
                found = label_0 != label_1
            if found:
                # Only load the partner's data once its label qualifies.
                return (img_0, label_0), (self.data[index_1], label_1)

The issue I'm facing now is that MyOwnDataset is a ConcatDataset of 6 datasets, each belonging to a different class. So when I need a sample from a different class, this does not work, because self.data[index] only contains samples from a single class.

Do you have an idea how I could solve this?

Thanks and best regards
Jonas

Your approach sounds reasonable.
I think you could change SiameseDataset a bit and just sample from the ConcatDataset as shown here:

class SiameseDataset(Dataset):
    """Wrap a map-style dataset so each item is an (anchor, partner) pair.

    The partner shares the anchor's label with probability 0.5 and has a
    different label otherwise. Works with any dataset whose items are
    ``(input, label)`` tuples, e.g. a ``ConcatDataset`` of per-class datasets.

    Args:
        dataset: map-style dataset returning ``(input, label)`` tuples.
        max_random_tries: number of random probes before falling back to a
            scan of every index in random order. This bounds the search, so
            an unsatisfiable request raises ``ValueError`` instead of
            spinning forever.
    """

    def __init__(self, dataset, max_random_tries=1000):
        super().__init__()
        self.dataset = dataset
        self.max_random_tries = max_random_tries

    def __getitem__(self, index):
        # We need approx 50 % of images of the same class.
        same_class = random.randint(0, 1)
        img_0, label_0 = self.dataset[index]
        img_1, label_1 = self._find_partner(label_0, bool(same_class))
        return (img_0, label_0), (img_1, label_1)

    def _find_partner(self, label_0, want_same):
        """Return a random ``(input, label)`` whose label equals (or differs from) ``label_0``.

        Raises:
            ValueError: if no sample in the dataset satisfies the relation.
        """

        def accepts(label_1):
            return bool(label_0 == label_1) if want_same else bool(label_0 != label_1)

        n = len(self)
        # Fast path: random probing with replacement (cheap when matches are common).
        for _ in range(self.max_random_tries):
            img_1, label_1 = self.dataset[random.randint(0, n - 1)]
            if accepts(label_1):
                return img_1, label_1
        # Slow path: visit each index once in random order. The first hit is
        # uniformly distributed among all matching samples, and the loop is finite.
        for index_1 in random.sample(range(n), n):
            img_1, label_1 = self.dataset[index_1]
            if accepts(label_1):
                return img_1, label_1
        relation = "the same" if want_same else "a different"
        raise ValueError(f"no sample in the dataset has {relation} label as {label_0!r}")

    def __len__(self):
        return len(self.dataset)

# Two toy datasets, one per class: 10 one-feature samples labelled 0 and 1.
dataset1 = TensorDataset(torch.randn(10, 1), torch.zeros(10))
dataset2 = TensorDataset(torch.randn(10, 1), torch.ones(10))

# Indices 0-9 map to dataset1 and 10-19 to dataset2.
concat_dataset = ConcatDataset([dataset1, dataset2])

dataset = SiameseDataset(concat_dataset)
# Each item is ((input, label), (input, label)) for the anchor and its partner.
(x0, y0), (x1, y1) = dataset[0]

Would that work for your use case?

Thanks a lot!
That solved my problem 🙂

Have a nice day!

Hi, unfortunately, this solution is not efficient: it searches for positive/negative pairs by random probing and loads every probed sample just to read its label. This can take a long time when the dataset is large. To find pairs faster, you can build a dictionary once that maps each label to the indices of its samples.


class ContrastiveDataset(Dataset):
    """Wrap a map-style dataset so each item is an (anchor, partner) pair.

    With probability ``positive_prob`` the partner is a *different* sample of
    the anchor's class (positive pair). Otherwise it is a sample drawn
    uniformly from all samples of other classes (negative pair). Labels are
    indexed once in ``__init__``, so pair sampling never loads a sample just
    to inspect its label.

    Args:
        dataset: map-style dataset returning ``(input, label)`` tuples, e.g.
            an ``ImageFolder`` or a ``ConcatDataset`` of such datasets.
        positive_prob: probability of drawing a same-class (positive) pair.

    Attributes:
        h: hash table mapping each label to the list of indices carrying it.
    """

    def __init__(self, dataset, positive_prob=0.5):
        super().__init__()
        self.dataset = dataset
        self.positive_prob = positive_prob

        # Hashable label key for every index, in dataset order.
        self._labels = [self._label_key(lbl) for lbl in self._collect_labels(dataset)]
        # Construct the hash table for the correspondences: label -> indices.
        self.h = {}
        for i, key in enumerate(self._labels):
            self.h.setdefault(key, []).append(i)

    @staticmethod
    def _label_key(label):
        """Convert tensor/NumPy scalar labels into plain Python values usable as dict keys."""
        # torch.Tensor hashes by identity, so equal tensor labels would
        # otherwise land in different buckets (and lookups would miss).
        item = getattr(label, "item", None)
        return item() if callable(item) else label

    @classmethod
    def _collect_labels(cls, dataset):
        """Return the raw label of every sample, avoiding sample loads where possible."""
        imgs = getattr(dataset, "imgs", None)
        if imgs is not None:
            # ImageFolder-style: list of (path, label) tuples.
            return [lbl for _, lbl in imgs]
        targets = getattr(dataset, "targets", None)
        if targets is not None:
            return list(targets)
        subsets = getattr(dataset, "datasets", None)
        if subsets is not None and hasattr(dataset, "cumulative_sizes"):
            # ConcatDataset: global indices run through the sub-datasets in order.
            return [lbl for sub in subsets for lbl in cls._collect_labels(sub)]
        # Generic fallback: read each sample once.
        return [dataset[i][1] for i in range(len(dataset))]

    def __getitem__(self, index):
        """Return ``((img_0, label_0), (img_1, label_1))`` for anchor ``index``.

        Raises:
            ValueError: if a positive pair is requested for the only sample
                of its class, or a negative pair from a single-class dataset.
        """
        if index < 0:
            # Normalize so the "partner != anchor" check compares like with like.
            index += len(self)
        # uniform < p is true with probability p (the old `>` test was inverted).
        same_class = random.random() < self.positive_prob
        img_0, label_0 = self.dataset[index]
        key_0 = self._labels[index]

        if same_class:
            class_samples = self.h[key_0]
            if len(class_samples) < 2:
                raise ValueError(
                    f"cannot build a positive pair for index {index}: "
                    f"it is the only sample of class {key_0!r}"
                )
            # Terminates: the class holds at least one other index.
            index_1 = index
            while index_1 == index:
                index_1 = random.choice(class_samples)
        else:
            if len(self.h) < 2:
                raise ValueError("cannot build a negative pair: the dataset contains a single class")
            # Rejection sampling on the precomputed labels only; no sample is
            # loaded until a different-class index is found.
            while True:
                index_1 = random.randrange(len(self._labels))
                if self._labels[index_1] != key_0:
                    break

        img_1, label_1 = self.dataset[index_1]
        return (img_0, label_0), (img_1, label_1)

    def __len__(self):
        return len(self.dataset)