Load the same number of samples per class

I know a weighted sampler can address the imbalanced-data problem.
However, is there a way to load exactly the same number of samples per class?
What I need is, for example, a batch with 10 samples from class A, 10 from class B, 10 from class C, etc. (I mean not probabilistically, but deterministically guaranteeing that exactly 10 samples per class are loaded.)

I would also like to know how to combine such a solution with torch.utils.data.DataLoader.

I had the same problem when re-implementing PFE, which takes as input 64 classes with 4 samples each.
Here's my solution (not very elegant, but it works for me).

#class Dataset
def __init__(self):
    """Record, for every class label, the dataset indices belonging to that class."""
    # Hard-coded toy mapping: label -> list of sample indices.
    label_to_index = {}
    label_to_index['A'] = [1, 3, 4]
    label_to_index['B'] = [0, 2, 5]
    label_to_index['C'] = [6, 7, 8]
    self.label_to_index = label_to_index

def __getitem__(self, item):
    """Return one randomly chosen sample from every class, in ``label_to_index`` order.

    ``item`` is ignored: each call draws a fresh random index per class, so a
    single dataset item holds exactly one sample of each class.

    NOTE(review): ``image_shape``, ``get_image`` and ``get_label`` are not
    defined in this snippet; they are assumed to exist elsewhere.

    Returns:
        (images, labels): tensors whose first dimension equals the number of
        classes in ``self.label_to_index``.
    """
    n_classes = len(self.label_to_index)
    # One row per class. Allocating a single row here and then writing
    # rows i = 0..n_classes-1 raises IndexError as soon as i reaches 1.
    images = np.empty((n_classes, *image_shape))
    labels = np.empty((n_classes,))
    for i, indices in enumerate(self.label_to_index.values()):
        # Draw one index as a scalar. (np.random.choice's keyword is
        # ``replace``; passing ``replacement=`` raises TypeError.)
        index = np.random.choice(indices)
        images[i, ...] = get_image(index)
        labels[i] = get_label(index)
    return torch.from_numpy(images), torch.from_numpy(labels)

As you can see, each item it returns contains one sample from A, one from B, and one from C. I then set batch_size = 10, so each batch holds 10 samples per class.
I think that if the dataset is large enough, then over many iterations, randomly choosing samples for each batch is roughly equivalent to shuffling the dataset and iterating over it.

1 Like

Thank you! That sounds nice. I'll try it.
I'm also trying to implement metric learning (NCA), as you did.

Finally, I wrote a custom BatchSampler.
I hope this code helps someone like me.
Reference here

import numpy as np
import torch

from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler


class BalancedBatchSampler(BatchSampler):
    """Yield batches with exactly ``n_samples`` examples from each of ``n_classes`` classes.

    For each batch, ``n_classes`` distinct labels are chosen uniformly at
    random. For each chosen label, the next ``n_samples`` not-yet-used indices
    of that label are taken. A label's index list is reshuffled as soon as
    fewer than ``n_samples`` unused indices remain. Every batch therefore has
    exactly ``batch_size = n_classes * n_samples`` indices.

    Use it via ``DataLoader(dataset, batch_sampler=BalancedBatchSampler(...))``.

    Args:
        dataset: map-style dataset whose items are ``(data, label)`` pairs with
            integer labels. If it exposes a ``targets`` attribute and has no
            ``target_transform`` (as torchvision's MNIST does), labels are read
            from ``targets`` without loading any data. In that case ``targets``
            is assumed to be aligned with ``dataset[i][1]``.
        n_classes: number of distinct classes per batch.
        n_samples: number of examples per class in each batch.

    Raises:
        ValueError: if ``n_classes`` or ``n_samples`` is < 1, if the dataset has
            fewer than ``n_classes`` distinct labels, or if some label has fewer
            than ``n_samples`` examples (a balanced batch would be impossible).
    """

    def __init__(self, dataset, n_classes, n_samples):
        if n_classes < 1 or n_samples < 1:
            # batch_size == 0 would make __iter__ loop forever.
            raise ValueError(
                f"n_classes and n_samples must be >= 1, got "
                f"n_classes={n_classes}, n_samples={n_samples}"
            )
        self.labels_list = self._collect_labels(dataset)
        self.labels = torch.tensor(self.labels_list, dtype=torch.long)
        labels_np = self.labels.numpy()
        # Sorted so the label order does not depend on set iteration order.
        self.labels_set = sorted(set(self.labels_list))
        if n_classes > len(self.labels_set):
            raise ValueError(
                f"n_classes={n_classes} exceeds the number of distinct labels "
                f"in the dataset ({len(self.labels_set)})"
            )
        self.label_to_indices = {label: np.where(labels_np == label)[0]
                                 for label in self.labels_set}
        too_small = [label for label, idx in self.label_to_indices.items()
                     if len(idx) < n_samples]
        if too_small:
            raise ValueError(
                f"labels {too_small} have fewer than n_samples={n_samples} "
                f"examples; balanced batches are impossible"
            )
        for indices in self.label_to_indices.values():
            np.random.shuffle(indices)
        # Per label: position of the next unused index in label_to_indices.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    @staticmethod
    def _collect_labels(dataset):
        """Return the label of every item of ``dataset`` as a list of ints."""
        targets = getattr(dataset, "targets", None)
        if targets is not None and getattr(dataset, "target_transform", None) is None:
            # Fast path: read labels without loading or transforming the data.
            return [int(t) for t in torch.as_tensor(targets).tolist()]
        # Generic path: index the dataset directly (no DataLoader/collate overhead).
        return [int(dataset[i][1]) for i in range(len(dataset))]

    def __iter__(self):
        self.count = 0
        # ``<=`` makes the number of yielded batches equal len(self). With ``<``,
        # the final batch was dropped whenever len(dataset) % batch_size == 0.
        while self.count + self.batch_size <= len(self.dataset):
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False).tolist()
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                chunk = self.label_to_indices[class_][start:start + self.n_samples]
                indices.extend(int(i) for i in chunk)
                self.used_label_indices_count[class_] += self.n_samples
                # Reshuffle once a full slice of n_samples is no longer available,
                # so every slice taken above has exactly n_samples entries.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        return len(self.dataset) // self.batch_size

MNIST example:

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt

n_classes = 5
n_samples = 8

mnist_train = torchvision.datasets.MNIST(
    root="mnist/mnist_train",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(),]),
)

# Each batch: n_classes distinct digits x n_samples images per digit.
balanced_batch_sampler = BalancedBatchSampler(mnist_train, n_classes, n_samples)

dataloader = torch.utils.data.DataLoader(mnist_train, batch_sampler=balanced_batch_sampler)
my_testiter = iter(dataloader)
# Use the builtin next(): current DataLoader iterators have no .next() method.
images, target = next(my_testiter)


def imshow(img):
    """Display an image tensor with matplotlib.

    Assumes ``img`` is channels-first (C, H, W), as returned by
    ``torchvision.utils.make_grid``.
    """
    npimg = img.numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


imshow(torchvision.utils.make_grid(images))
# Without show(), a non-interactive script exits before the figure is displayed.
plt.show()
4 Likes

Thanks, I desperately needed this.