Triplet data loader for CIFAR-10

I see two good ways to do it.

  1. Write a custom dataset that returns triplets. Flatten and batch the triplets -> model -> reconstruct the triplets -> loss.
  2. Write a dataset that doesn't return triplets directly. The dataset returns individual samples -> model -> construct all possible triplets from the labels (any sample from a different class can serve as a negative) -> loss. This gives you many triplets if your batch size is large.
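Option 2 can be sketched as a "batch-all" triplet loss: embed the whole batch once, then use the labels to enumerate every valid (anchor, positive, negative) combination. This is a minimal sketch, assuming the model outputs `(B, D)` embeddings, that positives are other samples of the anchor's class, and that averaging over all valid triplets is acceptable; `batch_all_triplet_loss` and `margin` are hypothetical names, not a library API.

```python
import torch
import torch.nn.functional as F


def batch_all_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor,
                           margin: float = 1.0) -> torch.Tensor:
    """Mean hinge loss over every valid triplet in the batch (hypothetical helper).

    embeddings: (B, D) model outputs
    labels:     (B,) integer class ids
    """
    # Pairwise Euclidean distances, shape (B, B)
    dist = torch.cdist(embeddings, embeddings, p=2)

    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (B, B): same class
    not_self = ~torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos_mask = same & not_self  # valid (anchor, positive) pairs
    neg_mask = ~same            # valid (anchor, negative) pairs

    # Broadcast to (B, B, B): loss[a, p, n] = d(a, p) - d(a, n) + margin
    loss = dist.unsqueeze(2) - dist.unsqueeze(1) + margin
    valid = pos_mask.unsqueeze(2) & neg_mask.unsqueeze(1)

    loss = F.relu(loss[valid])
    if loss.numel() == 0:
        # No valid triplets (e.g. every label in the batch is unique)
        return embeddings.sum() * 0.0
    return loss.mean()
```

In a training loop this replaces the reconstruction step: `loss = batch_all_triplet_loss(model(images), labels)`. The intermediate tensor is `(B, B, B)`, so memory grows cubically with batch size (about 2.1 million entries at `B = 128`).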

A dataset class to customize

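from torchvision import datasets

class MyCifar10(datasets.CIFAR10):
  def __init__(self, path, transforms, train=True):
    # Download CIFAR-10 to `path` if needed. No transform is given to the parent,
    # so super().__getitem__ returns the raw PIL image.
    super().__init__(path, train=train, download=True)
    # Overrides VisionDataset's own `transforms` attribute, which CIFAR10.__getitem__ does not use
    self.transforms = transforms

  def __getitem__(self, index):
    im, label = super().__getitem__(index)  # PIL image, int label
    # Customize here, e.g. also return the untransformed image or a second augmented view
    return self.transforms(im), label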
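Option 1 could build on a dataset like the one above by wrapping it so each item is a triplet, then flattening the triplets into one batch for the model. This is a minimal sketch, assuming the base dataset returns `(tensor, label)` (so the transforms must include `ToTensor`) and that you have its labels as a list (`datasets.CIFAR10` exposes them as `.targets`). `TripletDataset` and `triplet_step` are hypothetical names, and positives and negatives are sampled at random rather than mined.

```python
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


class TripletDataset(Dataset):
    """Hypothetical wrapper: yields (anchor, positive, negative) from a labelled dataset.

    base:   any dataset whose items are (tensor, label)
    labels: the labels of `base` in index order (for CIFAR10, `base.targets`)
    Requires at least two classes.
    """

    def __init__(self, base, labels):
        self.base = base
        self.labels = [int(lab) for lab in labels]
        # Index the samples by class so positives and negatives can be drawn quickly
        self.by_class = {}
        for idx, lab in enumerate(self.labels):
            self.by_class.setdefault(lab, []).append(idx)
        self.classes = list(self.by_class)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        anchor, _ = self.base[index]
        label = self.labels[index]
        # Positive: another sample of the same class (the anchor itself if it is alone)
        same = [i for i in self.by_class[label] if i != index] or [index]
        positive, _ = self.base[random.choice(same)]
        # Negative: a random sample from a random other class
        neg_class = random.choice([c for c in self.classes if c != label])
        negative, _ = self.base[random.choice(self.by_class[neg_class])]
        return anchor, positive, negative


def triplet_step(model, batch, loss_fn):
    """Flatten a batch of triplets, run the model once, then rebuild the triplets."""
    anchor, positive, negative = batch  # each (B, C, H, W)
    b = anchor.size(0)
    emb = model(torch.cat([anchor, positive, negative], dim=0))  # (3B, D)
    emb_a, emb_p, emb_n = emb.split(b, dim=0)  # back to three (B, D) tensors
    return loss_fn(emb_a, emb_p, emb_n)


# Usage with random stand-in data shaped like CIFAR-10 (3x32x32, 10 classes)
images = torch.randn(40, 3, 32, 32)
targets = torch.arange(40) % 10
triplets = TripletDataset(TensorDataset(images, targets), targets.tolist())
loader = DataLoader(triplets, batch_size=8, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64))
loss = triplet_step(model, next(iter(loader)), nn.TripletMarginLoss(margin=1.0))
loss.backward()
```

Running the three parts of each triplet through the model as one concatenated batch means a single forward pass, and `split` recovers the anchor, positive, and negative embeddings in the same order they were concatenated.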

Create triplets from a batch (won't work in your case, but it might give you something to start from):

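import torch

def create_triplets(originals, transformed):
  ''' Build (anchor, positive, negative) triplets from a batch.

  originals, transformed: 2-D (B, D) tensors (e.g. flattened images or embeddings),
  where transformed[i] is an augmented view of originals[i]. Anchor i gets its own
  transformed view as the positive, and every other sample in the batch, both
  transformed and non-transformed, as negatives. '''
  batch_size = originals.size(0)
  n_repeat = (batch_size - 1) * 2  # negatives per anchor: (B - 1) transformed + (B - 1) originals

  # Repeat each anchor/positive once per negative
  anchors = originals.repeat_interleave(n_repeat, dim=0)
  positives = transformed.repeat_interleave(n_repeat, dim=0)

  # Tile the batch B times (B*B rows) and drop the diagonal, where the "negative" would be the anchor itself
  mask = [i for i in range(batch_size**2) if i % (batch_size + 1) != 0]
  n1 = transformed.repeat(batch_size, 1)[mask]
  n2 = originals.repeat(batch_size, 1)[mask]
  # Interleave transformed and original negatives so they line up with the repeated anchors
  negatives = torch.stack((n1, n2), dim=1).view(-1, anchors.size(1))

  return anchors, positives, negatives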
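The mask is the least obvious line in `create_triplets`. The sketch below uses a toy batch with `B = 4` and a column of sample indices (both assumptions for illustration) to show that `i % (B + 1) != 0` drops exactly the diagonal of a B×B grid: row `i = a * B + j` of `x.repeat(B, 1)` holds sample `j` for anchor `a`, and `i % (B + 1) == 0` exactly when `j == a`. A `torch.eye` mask is an equivalent tensor-based way to write it.

```python
import torch

B = 4
# Flat indices kept by the list comprehension in create_triplets
kept = [i for i in range(B**2) if i % (B + 1) != 0]

# Equivalent boolean mask: False on the diagonal of the (B, B) grid, flattened
off_diag = ~torch.eye(B, dtype=torch.bool).flatten()

x = torch.arange(B).unsqueeze(1).float()         # (B, 1): row j holds the value j
negs = x.repeat(B, 1)[off_diag].view(B, B - 1)   # row a lists every sample index except a
print(negs)
```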