Triplet data loader for CIFAR10

Hi all, could anyone give me some tips (ideally with sample code) on how to select triplet data from CIFAR10?


How would you like to construct your triplets?

Hi, I would like to build triplets for a triplet network: (anchor, positive, negative), where the anchor and the positive come from the same class and the negative comes from a different class.

I see two good ways to do it.

  1. Write a custom dataset that returns triplets. Flatten and batch the triplets -> model -> reconstruct the triplets -> loss
  2. Write a dataset that doesn’t return triplets directly. Dataset returns the samples -> model -> construct all possible triplets from the labels (any sample from a different class can serve as a negative) -> loss. This gives you many triplets when the batch size is large.
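
Option 1 could be sketched roughly as below. This is only a minimal sketch under some assumptions: `TripletDataset` is a hypothetical name, the wrapped dataset is assumed to return `(sample, label)` pairs, and the labels are passed in separately so the wrapper does not have to load every image up front. It only implements `__len__` and `__getitem__`, which is what a map-style dataset needs.

```python
import random
from collections import defaultdict


class TripletDataset:
    """Hypothetical wrapper: turns a (sample, label) dataset into triplets."""

    def __init__(self, base, labels, seed=None):
        # Assumption: base[i] returns (sample, label) and labels[i] is the label of item i.
        self.base = base
        self.labels = list(labels)
        self.rng = random.Random(seed)
        # Group sample indices by class so positives and negatives are easy to draw.
        self.by_class = defaultdict(list)
        for idx, label in enumerate(self.labels):
            self.by_class[label].append(idx)
        self.classes = list(self.by_class)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        label = self.labels[index]
        # Positive: another sample of the same class
        # (falls back to the anchor itself if the class has a single sample).
        candidates = [i for i in self.by_class[label] if i != index] or [index]
        pos_index = self.rng.choice(candidates)
        # Negative: any sample from a randomly chosen different class.
        neg_label = self.rng.choice([c for c in self.classes if c != label])
        neg_index = self.rng.choice(self.by_class[neg_label])
        return self.base[index][0], self.base[pos_index][0], self.base[neg_index][0]
```

For CIFAR10, the labels could probably be taken from the torchvision dataset's `targets` attribute. Rebuilding the candidate list on every call is slow for large classes, so precomputing it may be worthwhile in practice.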

A dataset class to customize

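from torchvision import datasets

class MyCifar10(datasets.CIFAR10):
  def __init__(self, path, transforms, train=True):
    # Download CIFAR10 to `path` if it is not already there
    super().__init__(path, train=train, download=True)
    # Keep the transforms here and apply them ourselves in __getitem__
    self.transforms = transforms

  def __getitem__(self, index):
    # The parent class returns (PIL image, class index)
    im, label = super().__getitem__(index)
    return self.transforms(im), label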

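Create triplets from a batch (this won't work as-is in your case, since it uses a transformed copy of each image as the positive rather than class labels, but it might give you something to start from)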

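import torch

def create_triplets(originals, transformed):
  ''' Build triplets from 2D batches of shape (batch_size, features): each original is an anchor, its transformed copy is the positive, and every other sample (both transformed and original) is a negative '''
  batch_size = originals.size(0)
  # Each anchor gets 2 negatives for every other sample in the batch
  n_repeat = (batch_size - 1) * 2

  anchors = originals.repeat_interleave(n_repeat, dim=0)
  positives = transformed.repeat_interleave(n_repeat, dim=0)

  # Indices of the off-diagonal (anchor, other) pairs in the batch_size x batch_size grid
  mask = [i for i in range(batch_size**2) if i % (batch_size + 1) != 0]
  n1 = transformed.repeat(n_repeat + 1, 1)[mask]
  n2 = originals.repeat(n_repeat + 1, 1)[mask]
  # Interleave transformed and original negatives so they line up with the anchors
  negatives = torch.stack((n1, n2), dim=1).view(-1, anchors.size(1))

  return anchors, positives, negatives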
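
For the label-based case (option 2 above), something like the following might be closer to what you need. It is a plain-Python sketch that only builds index triplets from the labels of one batch; `batch_triplet_indices` is a hypothetical helper name, and it enumerates every valid triplet, which grows quickly with the batch size.

```python
from itertools import permutations


def batch_triplet_indices(labels):
    """Return every (anchor, positive, negative) index triplet in a batch.

    The anchor and positive are different samples with the same label;
    the negative is any sample with a different label.
    """
    labels = list(labels)
    triplets = []
    # Ordered pairs, so (a, p) and (p, a) both appear as anchor/positive.
    for a, p in permutations(range(len(labels)), 2):
        if labels[a] != labels[p]:
            continue
        for n, n_label in enumerate(labels):
            if n_label != labels[a]:
                triplets.append((a, p, n))
    return triplets


print(batch_triplet_indices([0, 0, 1]))  # [(0, 1, 2), (1, 0, 2)]
```

The three index columns can then be used to pick rows from the batch of embeddings before passing them to a triplet loss such as `torch.nn.TripletMarginLoss`. With large batches it may be better to sample a subset of the triplets instead of using all of them.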