I know a weighted sampler can help with the imbalanced-data problem.
However, I wonder: is there a way to load exactly the same number of samples per class?
What I need is, for example, a batch of 10 samples from class A, 10 from class B, 10 from class C, etc. (I mean not probabilistically, but deterministically making sure to load 10 samples per class.)
I also want to know how to combine that solution with torch.utils.data.DataLoader.
Eta_C (December 28, 2019, 7:42am) #2
I had the same problem when I re-implemented PFE, which needs 64 classes with 4 samples from each class as one input batch.
Here's my solution (not very good, but it works for me).
import numpy as np
import torch

# image_shape, get_image, and get_label are assumed to be defined elsewhere
class BalancedDataset(torch.utils.data.Dataset):
    def __init__(self):
        # map each class label to the dataset indices belonging to it
        self.label_to_index = {'A': [1, 3, 4], 'B': [0, 2, 5], 'C': [6, 7, 8]}

    def __getitem__(self, item):
        n_classes = len(self.label_to_index)
        images = np.empty((n_classes, *image_shape))
        labels = np.empty((n_classes,))
        for i, (label, indices) in enumerate(self.label_to_index.items()):
            # draw one random index from this class (with replacement)
            index = np.random.choice(indices, 1, replace=True)[0]
            images[i, ...] = get_image(index)
            labels[i] = get_label(index)
        return torch.from_numpy(images), torch.from_numpy(labels)
As you can see, it returns one sample from each of A, B, and C as a single item; I then set batch_size = 10.
I think that, if the dataset is large enough, then after a large number of iterations this per-item random choice behaves much like shuffling the dataset and iterating over it.
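For what it's worth, here is a minimal sketch of how such a dataset combines with torch.utils.data.DataLoader (assuming the BalancedDataset class above, and that it also defines __len__); the flatten calls merge the per-item class dimension into the batch dimension:

from torch.utils.data import DataLoader

loader = DataLoader(BalancedDataset(), batch_size=10)

for images, labels in loader:
    # images: (10, n_classes, *image_shape), labels: (10, n_classes)
    # merge the batch and class dimensions -> 10 samples per class
    images = images.flatten(0, 1)
    labels = labels.flatten(0, 1)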
Thank you! That sounds nice. I'll try it.
I'm also trying to implement metric learning (NCA), as you did.
Finally, I made a custom BatchSampler.
I hope this code helps someone like me.
Reference here
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler


class BalancedBatchSampler(BatchSampler):
    """
    BatchSampler - from an MNIST-like dataset, samples n_classes classes and,
    within each of these classes, samples n_samples.
    Returns batches of size n_classes * n_samples.
    """

    def __init__(self, dataset, n_classes, n_samples):
        # one pass over the dataset to collect every label
        loader = DataLoader(dataset)
        self.labels_list = []
        for _, label in loader:
            self.labels_list.append(label)
        self.labels = torch.cat(self.labels_list)
        self.labels_set = list(set(self.labels.numpy()))
        # map each label to the indices of its samples, in random order
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < len(self.dataset):
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # reshuffle a class once its indices are nearly used up
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size
MNIST example:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

n_classes = 5
n_samples = 8

mnist_train = torchvision.datasets.MNIST(root="mnist/mnist_train", train=True, download=True,
                                         transform=transforms.Compose([transforms.ToTensor()]))
balanced_batch_sampler = BalancedBatchSampler(mnist_train, n_classes, n_samples)

dataloader = torch.utils.data.DataLoader(mnist_train, batch_sampler=balanced_batch_sampler)
my_testiter = iter(dataloader)
images, target = next(my_testiter)  # Iterator.next() was removed in Python 3


def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


imshow(torchvision.utils.make_grid(images))
Teerath (teerathkumar) (August 14, 2020, 4:09am) #5
Thanks, I desperately needed this.
Thanks, Mika, it works like a charm.
I had to change the while condition to ensure that the last mini-batch is also delivered to the data loader: with "<", a dataset whose length is an exact multiple of the batch size silently drops its final batch.
# in __iter__'s while condition, change "<" to "<=".
while self.count + self.batch_size <= len(self.dataset):
# the rest of the code
In case anyone prefers using a library for this task, there is a similar sampler in PyTorch Metric Learning called MPerClassSampler. Refer to here
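A minimal sketch of how MPerClassSampler might be wired up (assuming pytorch-metric-learning is installed; dataset and labels are placeholders for your own Dataset and its per-sample label list):

from torch.utils.data import DataLoader
from pytorch_metric_learning.samplers import MPerClassSampler

# labels: the label of every sample in the dataset, in index order
sampler = MPerClassSampler(labels, m=8, batch_size=40,
                           length_before_new_iter=len(dataset))
# batch_size must be a multiple of m (here: 5 classes x 8 samples)
loader = DataLoader(dataset, batch_size=40, sampler=sampler)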