Ah ok, sorry for the misunderstanding.
You could do the following.
First, let’s create an imbalanced CIFAR10 dataset. In the original CIFAR10 dataset each class has 5000 instances.
For simplicity, let’s just use 500 instances of class 0, 5000 instances of class 1, 500 instances of class 2, and so on.
import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms

# Load CIFAR10
dataset = datasets.CIFAR10(
    root='YOUR_PATH',
    train=True,
    transform=transforms.ToTensor())
# Get all training targets and count the number of class instances
targets = np.array(dataset.train_labels)
classes, class_counts = np.unique(targets, return_counts=True)
nb_classes = len(classes)
print(class_counts)
# Create artificial imbalanced class counts
imbal_class_counts = [500, 5000] * 5
# Get class indices
class_indices = [np.where(targets == i)[0] for i in range(nb_classes)]
# Get imbalanced number of instances
imbal_class_indices = [class_idx[:class_count] for class_idx, class_count in zip(class_indices, imbal_class_counts)]
imbal_class_indices = np.hstack(imbal_class_indices)
# Set target and data to dataset
dataset.train_labels = targets[imbal_class_indices]
dataset.train_data = dataset.train_data[imbal_class_indices]
assert len(dataset.train_labels) == len(dataset.train_data)
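If you want to sanity-check the index-selection trick without downloading CIFAR10, here is a small self-contained sketch on toy labels (the names `targets` and `imbal_class_counts` mirror the snippet above, but the values are made up):

```python
import numpy as np

# Toy labels: 3 classes with 10 samples each
targets = np.repeat([0, 1, 2], 10)
imbal_class_counts = [2, 10, 2]

# Same trick as above: keep only the first `count` indices per class
class_indices = [np.where(targets == i)[0] for i in range(3)]
imbal_class_indices = np.hstack(
    [idx[:count] for idx, count in zip(class_indices, imbal_class_counts)])

sub_targets = targets[imbal_class_indices]
print(np.unique(sub_targets, return_counts=True)[1])  # -> [ 2 10  2]
```

The per-class counts of the subsampled targets should match `imbal_class_counts` exactly.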
Now that we have thrown out a lot of samples, let’s have a look at the training loop.
loader = DataLoader(
    dataset, batch_size=64, shuffle=True)
# Here we have an imbalanced dataset
for batch_idx, (data, target) in enumerate(loader):
    print('Batch {}, classes {}, count {}'.format(
        batch_idx, *np.unique(target.numpy(), return_counts=True)))
# Your model will most likely perform badly and overfit
# on the majority classes
In this loop, the samples are imbalanced and your model will most likely overfit on the majority classes.
You can also resample in a more imbalanced way to see the effect more clearly.
Try to train your model here and have a look at the performance.
To counter the imbalanced dataset, let’s create a WeightedRandomSampler, which draws samples with probabilities proportional to the given weights.
# Oversample the minority classes
targets = dataset.train_labels
class_count = np.unique(targets, return_counts=True)[1]
print(class_count)
weight = 1. / class_count
samples_weight = weight[targets]
samples_weight = torch.from_numpy(samples_weight)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
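As a quick illustration of why `weight[targets]` balances the classes: each sample gets the reciprocal of its class count as its weight, so every class contributes the same total probability mass to the sampler. A minimal numpy sketch with toy `targets` (not the CIFAR10 labels):

```python
import numpy as np

# Toy imbalanced labels: class 0 twice, class 1 five times, class 2 once
targets = np.array([0, 0, 1, 1, 1, 1, 1, 2])
class_count = np.unique(targets, return_counts=True)[1]  # [2 5 1]
weight = 1. / class_count
samples_weight = weight[targets]  # one weight per sample

# Each class's weights sum to 1.0, so all classes are
# equally likely to be drawn by the sampler
for c in range(3):
    print(c, samples_weight[targets == c].sum())
```

Note that the weights passed to WeightedRandomSampler don't have to sum to 1 overall; only their relative sizes matter.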
Now we can create a training loop again and have a look at the sample distribution.
weighted_loader = DataLoader(
    dataset, batch_size=64, sampler=sampler)
# Here we have a balanced dataset due to oversampling
for batch_idx, (data, target) in enumerate(weighted_loader):
    print('Batch {}, classes {}, count {}'.format(
        batch_idx, *np.unique(target.numpy(), return_counts=True)))
# Your model will probably perform better here
Now the samples are drawn in a roughly uniform way, so your model will probably perform better.
Try to train your model again and compare the accuracies.
Let me know if this example works for you!