Weighted Random Sampler still unbalanced

Hi everyone,
I have a custom dataset for which I use torch.utils.data.Dataset. The __getitem__(self, index) method returns [torch.Tensor, label] for a given index. The data is highly unbalanced.
To counter this, I use the WeightedRandomSampler and have the following function to obtain the per-sample weights:

def get_sample_weights(dataset_instance):
    # labels is a dict {sample_id: class label (0 or 1)}
    label_dict = dataset_instance.labels
    classes = []
    for key in label_dict:
        idx = label_dict[key]
        classes.append(idx)

    # per-sample labels as a tensor
    class_counts = torch.tensor(classes)
    # number of samples per class, sorted by class index
    class_sample_count = torch.tensor([(class_counts == t).sum() for t in torch.unique(class_counts, sorted=True)])
    # class weight = inverse class frequency
    weight = 1. / class_sample_count.float()
    # one weight per sample, looked up via its class label
    sample_weights = torch.tensor([weight[t] for t in class_counts])
    return sample_weights
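
As a quick sanity check of this function (a minimal sketch using a hypothetical stand-in object, since the real CustomDataSet needs more setup), the per-sample weights should come out as the reciprocal of each class count:

import torch

class FakeDataset:
    # hypothetical stand-in: only the labels attribute matters here
    labels = {10: 0, 11: 1, 12: 1, 13: 1}  # one sample of class 0, three of class 1

w = get_sample_weights(FakeDataset())
print(w)  # expected: tensor([1.0000, 0.3333, 0.3333, 0.3333])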

Afterwards I use

test = CustomDataSet(...)
weights = get_sample_weights(test)
# draws len(weights) indices per epoch, with replacement by default
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
dl = DataLoader(test, batch_size=100, collate_fn=stack_collate, sampler=sampler)

But the data I get from the DataLoader is still unbalanced. Do you have any idea what is wrong, or any tips on how to solve this issue?

Thank you in advance

Could you print the label statistics in each batch, i.e. the frequency of each label?

Hey,
sure!

print("The number of 0/1 labels: ",class_sample_count)
>>>The number of 0/1 labels:  tensor([ 1048, 10605])

Here are the 0/1 label counts per batch for a batch size of 100.

dl = DataLoader(test, batch_size=100, collate_fn=stack_collate, sampler=sampler)


for i, dtlst in enumerate(dl):
    dt, target = dtlst
    if i % 10 == 0:
        print(f"For the {i}th iteration : 0/1 = {len(target[target == 0])}/{len(target[target == 1])}")
>>>For the 0th iteration : 0/1 = 14/86
>>>For the 10th iteration : 0/1 = 8/92
>>>For the 20th iteration : 0/1 = 6/94
>>>For the 30th iteration : 0/1 = 5/95
>>>For the 40th iteration : 0/1 = 9/91
>>>For the 50th iteration : 0/1 = 13/87
>>>For the 60th iteration : 0/1 = 10/90

Thanks. Could you post the class count in the original dataset?

I am not quite sure what you mean by that. But my get_sample_weights(dataset_instance) takes an instance of my dataset class, where labels is an attribute. labels is a dictionary with an id as the key and the label 0 or 1 as the corresponding value (see the sketch below).
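
A minimal sketch of that structure (hypothetical, the real class takes more parameters and loads real feature tensors):

import torch
from torch.utils.data import Dataset

class CustomDataSet(Dataset):
    # simplified stand-in: self.labels maps sample id -> class label (0 or 1)
    def __init__(self, labels):
        self.labels = labels
        self.ids = list(labels.keys())

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        sample_id = self.ids[index]
        features = torch.zeros(5)  # placeholder for the real feature tensor
        return [features, self.labels[sample_id]]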

I would like to know the original class distribution to check whether your current sampler is changing this distribution at all.
Assuming your original imbalance is 9:1, you could compare your code to this one (updated from my previous example for Python 3):

import numpy as np
import torch

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print('target train 0/1: {}/{}'.format(
    len(np.where(target == 0)[0]), len(np.where(target == 1)[0])))

# per-class weight = inverse class frequency, then one weight per sample
class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight).double()
sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i,
        len(np.where(target.numpy() == 0)[0]),
        len(np.where(target.numpy() == 1)[0])))

This creates 1000 samples, where 900 belong to class 0 and 100 to class 1.
The output is:

target train 0/1: 900/100
batch index 0, 0/1: 46/54
batch index 1, 0/1: 52/48
batch index 2, 0/1: 49/51
batch index 3, 0/1: 40/60
batch index 4, 0/1: 40/60
batch index 5, 0/1: 54/46
batch index 6, 0/1: 52/48
batch index 7, 0/1: 56/44
batch index 8, 0/1: 61/39
batch index 9, 0/1: 48/52

which shows that each batch is approximately balanced now: with inverse-frequency weights, each class contributes the same total weight (count × 1/count = 1), so the classes are drawn 50/50 in expectation.


Ah, ok. So I create an instance of the dataset class and call the labels attribute, which returns a dictionary with ids as keys and the labels as values. Don't mind the parameters.

test = CustomDataSet(number_params_baseline=10, los_baseline=0)
label_dict = test.labels
label_dict
>>>{253656: 1, 239289: 1, 282580: 1, 201668: 1, 210325: 1,...} # snippet of the actual dictionary
len(label_dict)
>>>11653

To get the original distribution:

classcount = np.zeros(2)  # [count of label 0, count of label 1]
for key in label_dict.keys():
    idx = label_dict[key]
    classcount[idx] += 1
print(classcount)
>>>[ 1048. 10605.]

It seems that the sampler does not change the distribution.
I hope that helps.
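
One thing I still want to verify (purely my assumption, not something confirmed above): WeightedRandomSampler draws an index i with probability proportional to weights[i], and the DataLoader passes that index straight to __getitem__, so weights[i] has to describe the label that the dataset actually returns for index i. A rough alignment check could look like this:

import torch

per_class_weight = 1. / torch.tensor([1048., 10605.])  # counts printed above

mismatches = 0
for i in range(len(test)):
    _, label = test[i]  # label the dataset returns for index i
    if not torch.isclose(weights[i].float(), per_class_weight[label]):
        mismatches += 1
print("misaligned weights:", mismatches)  # should be 0 if weights line up with indices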