Hi, in my work I would like to use both triplet loss and cross entropy loss together. My dataset consists of folders. Usually I can load the image and label in the following way:
I came across this problem recently with ImageFolder, so I wrote a custom ImageFolder subclass for triplet sampling. It doesn't really support integration with cross-entropy loss (it returns only the three images, not their labels), but I still hope it helps someone.
class TripletImageFolder(torchvision.datasets.ImageFolder):
    """Serve (anchor, positive, negative) image triplets from an image folder.

    The parent class discovers the images and their class indices; this
    subclass draws index triplets on top of them. A triplet is a pair of
    distinct images of the same class plus one image of a different class.
    Intended for training; use a plain ``ImageFolder`` for testing.

    Attributes set here:
        n_triplets: number of triplets drawn by ``generate_triplets``
            (initialised to the number of samples in the folder).
        train_triplets: array of shape ``(n_triplets, 3)`` whose rows are
            ``[anchor, positive, negative]`` indices into ``self.samples``.
    """

    def __init__(self, *arg, **kw):
        """Forward every argument to the parent class, then draw the triplets."""
        super().__init__(*arg, **kw)
        self.n_triplets = len(self.samples)
        self.train_triplets = self.generate_triplets()

    def __len__(self):
        """Return the number of triplets currently held by the dataset.

        Tied to ``train_triplets`` rather than ``samples`` so the reported
        length stays valid when ``set_triplets`` installs a different number
        of triplets.
        """
        return len(self.train_triplets)

    def generate_triplets(self):
        """Draw ``self.n_triplets`` random index triplets.

        For each triplet a class is chosen with probability proportional to
        its number of images, considering only classes with at least two
        images (a positive pair needs two distinct images). Anchor and
        positive are two distinct images of that class; the negative is drawn
        uniformly from the images of all other classes.

        Returns:
            Integer array of shape ``(n_triplets, 3)`` with rows
            ``[anchor, positive, negative]``.

        Raises:
            ValueError: if fewer than two classes are present, or if no class
                has at least two images.
        """
        if self.n_triplets == 0:
            return np.empty((0, 3), dtype=np.int64)

        labels = np.asarray(self.targets)
        # class_ids[i] is the rank of labels[i] among the distinct labels;
        # class_sizes[k] is the number of samples whose class id is k.
        _, class_ids, class_sizes = np.unique(
            labels, return_inverse=True, return_counts=True
        )
        if class_sizes.size < 2:
            raise ValueError(
                "Triplet sampling needs at least two classes to draw a negative."
            )

        # Samples whose class can provide two distinct images.
        anchor_pool = np.flatnonzero(class_sizes[class_ids] >= 2)
        if anchor_pool.size == 0:
            raise ValueError(
                "Triplet sampling needs at least one class with two or more images."
            )

        # Group the sample indices by class once instead of re-scanning the
        # whole label array for every triplet.
        order = np.argsort(class_ids, kind="stable")
        members = np.split(order, np.cumsum(class_sizes)[:-1])

        n_samples = labels.size
        triplets = np.empty((self.n_triplets, 3), dtype=np.int64)
        for row in range(self.n_triplets):
            class_id = class_ids[np.random.choice(anchor_pool)]
            idx_a, idx_p = np.random.choice(members[class_id], 2, replace=False)
            # Rejection sampling is uniform over the other classes' samples;
            # it terminates because at least two classes exist.
            idx_n = np.random.randint(n_samples)
            while class_ids[idx_n] == class_id:
                idx_n = np.random.randint(n_samples)
            triplets[row] = (idx_a, idx_p, idx_n)
        return triplets

    def set_triplets(self, triplets):
        """Replace the current triplets, e.g. with a fresh ``generate_triplets()``."""
        self.train_triplets = triplets

    def __getitem__(self, index):
        """Load the three images of triplet ``index``.

        Returns:
            ``(anchor, positive, negative)``, each passed through
            ``self.transform`` when one is set. Class labels are not returned.
        """
        idx_a, idx_p, idx_n = self.train_triplets[index]
        # self.samples holds (path, class_index) pairs; only the path is used.
        images = [self.loader(self.samples[i][0]) for i in (idx_a, idx_p, idx_n)]
        if self.transform is not None:
            images = [self.transform(img) for img in images]
        img_a, img_p, img_n = images
        return img_a, img_p, img_n
# Build the triplet dataset from the training folder, then wrap it in a
# shuffling, batching loader.
# NOTE(review): `data` is presumably `torch.utils.data` — confirm the import.
triplet_data = TripletImageFolder(
    root=TRAIN_DATA_PATH,
    transform=TRANSFORM_IMG,
)
triplet_dataloader = data.DataLoader(
    triplet_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
The triplets stay fixed once the dataset has been initialized. If you want to resample them (for example, after each pass over the data), generate a new set and hand it back to the dataset:
# One pass over the dataset, one batch of triplets at a time.
for img_a, img_p, img_n in triplet_dataloader:
    # Embed anchor, positive and negative (in that order) with the same model.
    emb_a, emb_p, emb_n = (model(img) for img in (img_a, img_p, img_n))
    loss = loss_fn(emb_a, emb_p, emb_n)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Resample the triplets once the training cycle above has finished.
triplets = triplet_dataloader.dataset.generate_triplets()
triplet_dataloader.dataset.set_triplets(triplets)