Data loader for triplet loss + cross-entropy loss

Hi, in my work I would like to use triplet loss and cross-entropy loss together. My dataset is organized into folders (one per class). Usually I load the images and labels like this:

# Training pipeline: resize to 224x224, augment (random flip, shear/scale,
# color jitter), then convert to a tensor and map each channel from [0, 1]
# to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): ColorJitter(brightness=1, contrast=1, saturation=1) samples
# each factor from [0, 2]; factors near 0 give a nearly black, flat gray or
# grayscale image respectively. Confirm this strength is intended.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=0, shear=10, scale=(0.8, 1.2)),
    transforms.ColorJitter(brightness=1, contrast=1, saturation=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Evaluation pipeline: same resize and normalization, no augmentation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# ImageFolder expects one sub-folder per class under each split directory.
train_dataset = datasets.ImageFolder(root='.././data/flower-photos/train', transform=transform_train)
val_dataset = datasets.ImageFolder(root='.././data/flower-photos/test', transform=transform)

# Shuffle only the training data; keep the validation order fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

I need each sample to return anchor_img, positive_img, negative_img, anchor_label. How can I do that? Thanks in advance.

1 Like

You could create a custom Dataset as described e.g. here.

I came across this problem recently with ImageFolder, so I wrote a custom ImageFolder for triplet sampling. It doesn't really support integration with cross-entropy loss, but I still hope it helps someone.

class TripletImageFolder(torchvision.datasets.ImageFolder):
    """``ImageFolder`` that yields (anchor, positive, negative) image triplets.

    A triplet is two different images of the same class (anchor and positive)
    plus one image of a different class (negative). Intended for training; use
    a plain ``ImageFolder`` for evaluation.

    Triplets are sampled once in ``__init__`` (as many triplets as images).
    To resample, e.g. after every epoch, call
    ``set_triplets(generate_triplets())``.

    Args:
        *arg, **kw: forwarded unchanged to ``torchvision.datasets.ImageFolder``
            (``root``, ``transform``, ``target_transform``, ``loader``, ...).
        return_labels: if True, ``__getitem__`` also returns the anchor's class
            label, so the anchor can additionally feed a cross-entropy loss.
            Defaults to False, which keeps the original 3-tuple output.
    """

    def __init__(self, *arg, return_labels=False, **kw):
        super().__init__(*arg, **kw)
        self.return_labels = return_labels
        # One triplet per image in the folder.
        self.n_triplets = len(self.samples)
        self.train_triplets = self.generate_triplets()

    def __len__(self):
        # Size the epoch by the current triplet table rather than by
        # len(self.samples): set_triplets() may install a table of another length.
        return len(self.train_triplets)

    def generate_triplets(self):
        """Return ``self.n_triplets`` random triplets as an ``(n_triplets, 3)`` int array.

        Each row is ``(anchor, positive, negative)`` indices into ``self.samples``.
        The anchor class is picked with probability proportional to its size,
        among classes that have at least two images. Anchor and positive are
        distinct images of that class. The negative is drawn uniformly from all
        images of the other classes.

        Raises:
            ValueError: if there are fewer than two classes, or no class has at
                least two images, so no valid triplet exists.
        """
        targets = np.asarray(self.targets)
        n_samples = len(targets)
        # class_of[k]: dense class id of sample k; counts[c]: size of class c.
        _, class_of, counts = np.unique(targets, return_inverse=True, return_counts=True)
        # Group sample indices by class, computed once instead of rescanning all
        # labels per triplet: class c occupies order[starts[c]:starts[c] + counts[c]].
        order = np.argsort(class_of, kind="stable")
        starts = np.cumsum(counts) - counts

        # An anchor needs a distinct positive from its own class and a negative
        # from some other class; singleton classes cannot provide anchors.
        eligible = np.flatnonzero(counts[class_of] >= 2)
        if counts.size < 2 or eligible.size == 0:
            raise ValueError(
                "TripletImageFolder needs at least two classes and at least one "
                "class with two or more images to build triplets"
            )

        # Choosing the anchor class through a uniformly drawn eligible sample
        # weights classes by their size.
        cls = class_of[eligible[np.random.randint(0, eligible.size, size=self.n_triplets)]]
        first = starts[cls]
        count = counts[cls]

        # Two distinct positions inside the anchor's class: draw the positive
        # from count - 1 slots and shift it past the anchor's slot.
        pos_a = np.random.randint(0, count)
        pos_p = np.random.randint(0, count - 1)
        pos_p = pos_p + (pos_p >= pos_a)

        # Negative: uniform over the other classes' samples. Draw a position in
        # `order` with the anchor class's slice removed, then skip over that slice.
        pos_n = np.random.randint(0, n_samples - count)
        pos_n = pos_n + np.where(pos_n >= first, count, 0)

        return np.stack([order[first + pos_a], order[first + pos_p], order[pos_n]], axis=1)

    def set_triplets(self, triplets):
        """Replace the triplet table used by ``__getitem__`` and ``__len__``.

        ``triplets`` is indexed by row; each row is an ``(anchor, positive,
        negative)`` triple of indices into ``self.samples``, e.g. the output of
        ``generate_triplets()``.

        NOTE(review): DataLoader workers kept alive across epochs
        (``persistent_workers=True``) presumably hold their own copy of the
        dataset and may not see this change. Confirm before relying on it.
        """
        self.train_triplets = triplets

    def __getitem__(self, index):
        """Load and transform the images of triplet ``index``.

        Returns ``(img_a, img_p, img_n)``, or ``(img_a, img_p, img_n, label_a)``
        when ``return_labels`` is True. ``label_a`` is the anchor's class index
        from ``self.samples``, passed through ``target_transform`` if one is set.
        """
        idx_a, idx_p, idx_n = self.train_triplets[index]

        path_a, label_a = self.samples[idx_a]
        path_p, _ = self.samples[idx_p]
        path_n, _ = self.samples[idx_n]

        img_a = self.loader(path_a)
        img_p = self.loader(path_p)
        img_n = self.loader(path_n)

        # Each image is transformed independently, so random augmentations
        # differ between anchor, positive and negative.
        if self.transform is not None:
            img_a = self.transform(img_a)
            img_p = self.transform(img_p)
            img_n = self.transform(img_n)

        if not self.return_labels:
            return img_a, img_p, img_n

        if self.target_transform is not None:
            label_a = self.target_transform(label_a)
        return img_a, img_p, img_n, label_a

# Training set whose items are (anchor, positive, negative) image triplets.
triplet_data = TripletImageFolder(root=TRAIN_DATA_PATH, transform=TRANSFORM_IMG)
# shuffle=True reorders the dataset's pre-sampled triplet table each epoch.
triplet_dataloader = data.DataLoader(dataset=triplet_data, batch_size=BATCH_SIZE, shuffle=True)

The triplets are fixed once the dataset is initialized, but if you want to resample them (e.g. after each epoch) you can do:

# Train for several epochs, drawing fresh triplets after each one. Otherwise
# the model would see the same triplet table that was sampled when the
# dataset was created. NUM_EPOCHS is assumed to be defined next to
# BATCH_SIZE — TODO confirm.
for epoch in range(NUM_EPOCHS):
    # One pass over the current triplet table.
    for img_a, img_p, img_n in triplet_dataloader:
        emb_a = model(img_a)
        emb_p = model(img_p)
        emb_n = model(img_n)
        loss = loss_fn(emb_a, emb_p, emb_n)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Resample triplets after each training cycle so the next epoch uses new ones.
    triplets = triplet_dataloader.dataset.generate_triplets()
    triplet_dataloader.dataset.set_triplets(triplets)