Use Weighted Random Sampler for Imbalanced Class

Hi,
I am trying to use WeightedRandomSampler in this way

class_sample_count = [39736,949, 7807]
weights = 1 / torch.Tensor(class_sample_count)
weights = weights.double()
sampler = torch.utils.data.sampler.WeightedRandomSampler(
weights=weights,
num_samples=?,
replacement=False)
dataloaders = {x: torch.utils.data.DataLoader(image_datasets, drop_last=True, sampler=sampler, batch_size=32) for x in ['train', 'valid']}

I want the minority class samples to be drawn at least once. What num_samples should I use? Also, am I using it the right way, given that all the samples I am seeing come only from the minority class? Thank you in advance.

The weights tensor should contain a weight for each sample, not the class weights.
Have a look at this post for an example.
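As a minimal sketch of what "a weight for each sample" means, the snippet below expands the class counts from the question into one weight per sample. The label tensor `targets` is synthetic and only stands in for the real dataset labels (an assumption), and `replacement=True` with `num_samples` equal to the dataset size is one common choice, not the only valid one.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Class counts from the question; the label tensor is synthetic (an assumption).
class_sample_count = torch.tensor([39736, 949, 7807])
targets = torch.repeat_interleave(torch.arange(3), class_sample_count)

# One weight per class, then index by label to get one weight per sample.
class_weights = 1.0 / class_sample_count.double()
sample_weights = class_weights[targets]

generator = torch.Generator().manual_seed(0)  # only for reproducibility
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # one "epoch" worth of draws
    replacement=True,
    generator=generator,
)

# Check the class distribution of the drawn indices.
drawn = torch.tensor(list(sampler))
class_share = torch.bincount(targets[drawn], minlength=3).double() / len(drawn)
print(class_share)  # each entry should be close to 1/3
```

With real data, `targets` would come from the dataset itself, and the sampler would normally be passed only to the training DataLoader.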

Okay I understand,
Thanks again @ptrblck. Your sample code in the link helped a lot :)

@ptrblck - I'm using a weighted random sampler for an imbalanced class problem and have a question about the replacement parameter. Does passing it as False ensure that no sample is repeated within a batch? All of my classes have more samples than the batch size.

replacement=False will not draw the same sample more than once during the entire epoch, so while the first batches might have a balanced class distribution, the later ones will yield more samples of the majority classes.
Thus, I don’t think using replacement=False is a proper way to balance the data batches.
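The drift toward the majority class can be checked with a small experiment. This is only a sketch on synthetic labels (100 minority and 900 majority samples, an assumption), comparing the first and last 100 draws of a full epoch sampled without replacement.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Synthetic labels (an assumption): 100 minority (1) and 900 majority (0) samples.
targets = torch.cat([torch.ones(100, dtype=torch.long),
                     torch.zeros(900, dtype=torch.long)])
class_weights = 1.0 / torch.bincount(targets).double()
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(targets),  # a full pass over the dataset
    replacement=False,
    generator=torch.Generator().manual_seed(0),  # only for reproducibility
)
order = torch.tensor(list(sampler))

# Count minority samples among the first and the last 100 draws.
minority_first = int(targets[order[:100]].sum())
minority_last = int(targets[order[-100:]].sum())
print(minority_first, minority_last)
```

The first 100 draws should contain a large share of minority samples, while the last 100 should be almost entirely majority samples, because the minority class has been used up by then.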

@ptrblck Isn’t it important to state that with replacement=False, the sampler doesn’t consider weights as one might expect?

For example, I have a dataset with 1000 samples: 300 cats and 700 dogs. I assign weights as [1/300, 1/700] and set num_samples to twice the size of the minority class (300*2=600). Using replacement=False, I would expect to get a balanced subset each epoch containing approximately 300 cats and 300 dogs. However, I don’t actually get a balanced subset.

While I can set replacement=True, it leads to duplicating samples within an epoch, which I want to avoid.

Using replacement=False allows you to undersample the dataset (you would need to stop iterating the dataset in this case) since the weights will still be used, but the sampler will be forced to pick imbalanced samples towards the end of the dataset. Without duplicating samples it will be impossible to create the same number of batches (as the full dataset will yield) in a balanced way.

After reviewing the source code, I understand why the imbalance occurs at the end of the dataset. This is due to the behavior of torch.multinomial.
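The cats-and-dogs case can be reproduced by calling torch.multinomial directly. This is a sketch with synthetic labels (an assumption). Each draw without replacement removes the picked sample, so the remaining total weight of the cats shrinks faster than that of the dogs, and dogs become more likely as sampling goes on.

```python
import torch

# 300 cats (label 0) and 700 dogs (label 1), as in the example above.
targets = torch.cat([torch.zeros(300, dtype=torch.long),
                     torch.ones(700, dtype=torch.long)])
sample_weights = torch.tensor([1 / 300, 1 / 700], dtype=torch.double)[targets]

generator = torch.Generator().manual_seed(0)  # only for reproducibility
picked = torch.multinomial(
    sample_weights, num_samples=600, replacement=False, generator=generator
)

num_cats = int((targets[picked] == 0).sum())
num_dogs = int((targets[picked] == 1).sum())
print(num_cats, num_dogs)  # cats + dogs == 600, but not 300 / 300
```

The split should come out with clearly fewer than 300 cats and more than 300 dogs, even though the initial per-class total weights are equal.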

Yes, that’s why I’m setting num_samples=2*num_minor_class. This theoretically allows for sampling balanced classes without duplicates. Initially, I thought this would happen based on the sampler’s documentation.
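If the goal is an exactly balanced subset with no duplicates each epoch, one option outside of WeightedRandomSampler is to pick the indices per class explicitly. The sketch below is not from this thread: `balanced_subset_indices` is a hypothetical helper, the labels are synthetic, and it simply undersamples every class to the size of the smallest one before handing the indices to SubsetRandomSampler.

```python
import torch
from torch.utils.data import SubsetRandomSampler

# Synthetic labels (an assumption): 300 cats (0) and 700 dogs (1).
targets = torch.cat([torch.zeros(300, dtype=torch.long),
                     torch.ones(700, dtype=torch.long)])

def balanced_subset_indices(targets, generator=None):
    """Hypothetical helper: draw the same number of distinct indices from every class."""
    classes, counts = torch.unique(targets, return_counts=True)
    per_class = int(counts.min())  # size of the smallest class
    chosen = []
    for c in classes:
        class_idx = torch.nonzero(targets == c).flatten()
        # Random permutation without repeats, truncated to per_class entries.
        perm = torch.randperm(len(class_idx), generator=generator)[:per_class]
        chosen.append(class_idx[perm])
    return torch.cat(chosen)

generator = torch.Generator().manual_seed(0)  # only for reproducibility
indices = balanced_subset_indices(targets, generator)
sampler = SubsetRandomSampler(indices.tolist(), generator=generator)

drawn = list(sampler)
print(len(drawn))  # 600 distinct indices: 300 cats and 300 dogs
```

Calling the helper again at the start of each epoch and rebuilding the sampler would give a fresh subset of the majority class every time.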

I recommend adding a note near the weights argument indicating that weights may not be used as expected when replacement=False.

I believe they are used as expected. However, when no replacement is allowed, a sample that has already been drawn cannot be picked again. Higher weights will still indicate a higher likelihood of being selected.