I see two good ways to do it.
- Write a custom dataset that returns triplets directly. Flatten & batch the triplets -> model -> reconstruct triplets -> loss.
- Write a dataset that doesn't return triplets directly. The dataset returns individual samples -> model -> construct all possible triplets based on the labels (every sample from a different class can serve as a negative) -> loss. This gives you many triplets per batch if your batch size is large.
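To make the second option concrete, here is a minimal sketch of mining triplet *indices* from a batch of labels. `triplets_from_labels` is a hypothetical helper, not part of any library; positives are same-class samples, negatives are everything else:

```python
import torch

def triplets_from_labels(labels):
    # Hypothetical helper: enumerate all (anchor, positive, negative)
    # index triplets from a 1-D tensor of class labels.
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (B, B) same-class mask
    eye = torch.eye(len(labels), dtype=torch.bool)
    pos = same & ~eye   # positives: same class, but not the anchor itself
    neg = ~same         # negatives: any sample from a different class
    triplets = []
    for a in range(len(labels)):
        for p in pos[a].nonzero(as_tuple=True)[0]:
            for n in neg[a].nonzero(as_tuple=True)[0]:
                triplets.append((a, p.item(), n.item()))
    return triplets

labels = torch.tensor([0, 0, 1])
print(triplets_from_labels(labels))  # [(0, 1, 2), (1, 0, 2)]
```

The number of triplets grows roughly cubically with batch size, which is why this approach yields many triplets per batch.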
A dataset class to customize:
```python
from torchvision import datasets

class MyCifar10(datasets.CIFAR10):
    def __init__(self, path, transforms, train=True):
        # Download CIFAR-10 to `path`; transforms are applied manually below.
        super().__init__(path, train, download=True)
        self.transforms = transforms

    def __getitem__(self, index):
        # Return the transformed image alongside its label.
        im, label = super().__getitem__(index)
        return self.transforms(im), label
```
Create triplets from a batch (won't work directly in your case, but it might give you something to start from):
```python
import torch

def create_triplets(originals, transformed):
    '''Create anchors (original images), positives (transformed images), and
    negatives drawn from the other samples in the batch, both transformed and
    non-transformed. Expects 2-D inputs of shape (batch_size, features).'''
    batch_size = originals.size(0)
    # Each anchor gets (batch_size - 1) transformed negatives plus
    # (batch_size - 1) non-transformed negatives.
    n_repeat = (batch_size - 1) * 2
    anchors = originals.repeat_interleave(n_repeat, dim=0)
    positives = transformed.repeat_interleave(n_repeat, dim=0)
    # Indices of a (batch_size x batch_size) grid with the diagonal removed,
    # i.e. every sample in the batch except the anchor itself.
    mask = [i for i in range(batch_size ** 2) if i % (batch_size + 1) != 0]
    n1 = transformed.repeat(n_repeat + 1, 1)[mask]
    n2 = originals.repeat(n_repeat + 1, 1)[mask]
    # Interleave the two negative sets so they line up with the repeated anchors.
    negatives = torch.stack((n1, n2), dim=1).view(-1, anchors.size(1))
    return anchors, positives, negatives
```
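A quick shape check, with the function repeated so the snippet runs on its own. Each of 4 anchors gets 2 × 3 negatives, so all outputs have 24 rows, which you can feed straight into `torch.nn.functional.triplet_margin_loss`:

```python
import torch

def create_triplets(originals, transformed):
    # Repeated from above for a self-contained demo.
    batch_size = originals.size(0)
    n_repeat = (batch_size - 1) * 2
    anchors = originals.repeat_interleave(n_repeat, dim=0)
    positives = transformed.repeat_interleave(n_repeat, dim=0)
    mask = [i for i in range(batch_size ** 2) if i % (batch_size + 1) != 0]
    n1 = transformed.repeat(n_repeat + 1, 1)[mask]
    n2 = originals.repeat(n_repeat + 1, 1)[mask]
    negatives = torch.stack((n1, n2), dim=1).view(-1, anchors.size(1))
    return anchors, positives, negatives

originals = torch.randn(4, 8)   # 4 flattened samples of dim 8
transformed = originals + 0.1   # stand-in for augmented versions
a, p, n = create_triplets(originals, transformed)
print(a.shape, p.shape, n.shape)  # each torch.Size([24, 8])
loss = torch.nn.functional.triplet_margin_loss(a, p, n)
```

For real images you would flatten with `view(batch_size, -1)` first, or rework the `repeat` calls to handle 4-D tensors.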