I wrote a function that returns a balanced sampler built on SubsetRandomSampler, which can be passed as the sampler to a DataLoader. The function works well for me and might be useful for others. I would welcome any comments, suggestions, or improvements.
import numpy as np
import torch


def get_a_balanced_sampler(score, sampler_size, labels, no_classes):
    '''
    Args in -
        score: posterior values in the range [0, 1], one per label; a value
            close to 1 indicates that the label is likely to be of the
            correct class
        sampler_size: the sampler size the user wants to get back; the size
            of the returned sampler can be slightly smaller than this because
            it is rounded down to a multiple of no_classes, and it is limited
            by the minimum number of labels per class
        labels: an array containing the labels (integers 0 .. no_classes - 1)
        no_classes: the number of classes in the problem
    Parameters -
        percentage_of_selected_samples: the fraction (50% here) of samples
            with the highest 'score' values that forms the pool for each
            class; the selection is then made randomly from this pool in a
            balanced manner. The 50% can be changed according to the user
            requirements, e.g., using lower or higher values.
    '''
    percentage_of_selected_samples = 50 / 100
    score = np.asarray(score)
    labels = np.asarray(labels)

    # Dataset indices of the samples that belong to each class
    idx_per_class = [np.flatnonzero(labels == i) for i in range(no_classes)]
    len_labels_per_class = np.array([len(idx) for idx in idx_per_class])

    # The smallest class bounds the pool size for every class
    no_labels_per_class = len_labels_per_class.min()
    sampler_pool_size = int(no_labels_per_class * percentage_of_selected_samples)
    samples_per_class = int(sampler_size / no_classes)
    if samples_per_class > sampler_pool_size:
        raise ValueError(
            'get_a_balanced_sampler(): the per-class sample size is larger '
            'than sampler_pool_size; decrease sampler_size or increase '
            'percentage_of_selected_samples (currently %s)'
            % percentage_of_selected_samples)

    my_sampler = []
    for class_idx in idx_per_class:
        # Rank this class by descending score and keep the top of the ranking
        order = (-score[class_idx]).argsort()[:sampler_pool_size]
        # Map back to dataset indices, shuffle, and keep the requested number
        sample_idx = np.random.permutation(class_idx[order])[:samples_per_class]
        my_sampler.extend(sample_idx.tolist())

    if len(my_sampler) < 100:
        raise ValueError('get_a_balanced_sampler(): a small sampler has been generated')
    return torch.utils.data.sampler.SubsetRandomSampler(my_sampler)
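
For anyone who wants to check the selection logic in isolation, here is a minimal NumPy-only sketch of the step the docstring describes: keep the top-scoring fraction of each class, then draw the same number of dataset indices at random from each class. The helper name `balanced_top_score_indices` and its `per_class`, `top_fraction`, and `seed` parameters are my own assumptions for illustration, not part of any library, and the toy data is made up.

```python
import numpy as np


def balanced_top_score_indices(score, labels, n_classes, per_class,
                               top_fraction=0.5, seed=None):
    """Hypothetical helper: return per_class dataset indices for each class,
    drawn at random from the top_fraction highest-scoring samples of it."""
    score = np.asarray(score, dtype=float)
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)

    # Dataset indices belonging to each class
    class_indices = [np.flatnonzero(labels == c) for c in range(n_classes)]

    # The smallest class bounds the pool so every class is treated alike
    pool_size = int(min(len(idx) for idx in class_indices) * top_fraction)
    if per_class > pool_size:
        raise ValueError(f"per_class={per_class} exceeds pool_size={pool_size}")

    selected = []
    for idx in class_indices:
        # Positions within the class, ordered by descending score
        order = np.argsort(-score[idx])[:pool_size]
        pool = idx[order]  # map positions back to dataset indices
        selected.extend(rng.choice(pool, size=per_class, replace=False).tolist())
    return selected


# Toy data (assumed): 3 classes with 20, 30 and 40 samples and random scores
toy_rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], [20, 30, 40])
score = toy_rng.random(labels.size)

indices = balanced_top_score_indices(score, labels, n_classes=3,
                                     per_class=5, seed=0)
counts = np.bincount(labels[indices], minlength=3)
print(counts)  # → [5 5 5]
```

The returned list holds indices into the original dataset, so it can be wrapped in `torch.utils.data.SubsetRandomSampler(indices)` and handed to a `DataLoader` through its `sampler` argument.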