Clarification on WeightedRandomSampler


So I have an imbalanced dataset with 2 classes. I searched for solutions and came across the WeightedRandomSampler, which many people seem to use.

However, I do not fully understand what it is used for. As far as I understand, the WeightedRandomSampler makes sure (if used correctly) that each batch contains approximately the same number of samples from each class.

Let's assume I have 80 samples of class “0” and 20 samples of class “1”.
I now use 10 batches of size 10 and want to use the WeightedRandomSampler. That means I draw ~50 samples from class “0” and ~50 from class “1”, so at least 30 of the class “1” draws are duplicates, which leads to biased learning.
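A minimal sketch of the setup described above, assuming PyTorch and a toy tensor dataset (the feature value of each sample is just its index, so the drawn samples can be identified). Weighting each sample by the inverse of its class count gives both classes the same total weight, so roughly half of the 100 draws are class “1”. With only 20 distinct class “1” samples, that forces repeats:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy imbalanced dataset: 80 samples of class 0, 20 samples of class 1.
targets = torch.cat([torch.zeros(80, dtype=torch.long), torch.ones(20, dtype=torch.long)])
data = torch.arange(100, dtype=torch.float32).unsqueeze(1)  # feature value == sample index
dataset = TensorDataset(data, targets)

# Inverse class frequency: each class gets the same total weight (80 * 1/80 == 20 * 1/20).
class_counts = torch.bincount(targets)              # tensor([80, 20])
class_weights = 1.0 / class_counts.double()
sample_weights = class_weights[targets]             # one weight per sample

# 10 batches of 10 -> 100 draws, with replacement.
sampler = WeightedRandomSampler(sample_weights, num_samples=100, replacement=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

# Iterate once and record which samples were drawn.
drawn_idx, drawn_y = [], []
for x, y in loader:
    drawn_idx.append(x.squeeze(1).long())
    drawn_y.append(y)
drawn_indices = torch.cat(drawn_idx)
drawn_targets = torch.cat(drawn_y)

minority_drawn = drawn_indices[drawn_targets == 1]
print("draws per class:", torch.bincount(drawn_targets, minlength=2).tolist())
print("class 1 draws:", minority_drawn.numel(),
      "distinct class 1 samples:", minority_drawn.unique().numel())
```

The exact counts change from run to run; the sampler also accepts a seeded torch.Generator via its generator argument if reproducibility matters.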

Now I could set “replacement” to False, but then I'm stuck with 20 samples of class “1”. That means after 4 batches (each with 5 samples from class “1”), the remaining batches have no class “1” samples at all. So I'm not sure how that is supposed to help.
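To see what replacement=False actually does here, a small sketch under the same assumptions (toy 80/20 targets, 10 batches of 10). Without replacement, drawing all 100 indices produces a weighted permutation of the dataset: every class “1” sample appears exactly once per epoch, and because of their higher weight they tend to be concentrated in the first batches, which is the effect described above. The seed only makes the output repeatable:

```python
import torch
from torch.utils.data import WeightedRandomSampler

targets = torch.cat([torch.zeros(80, dtype=torch.long), torch.ones(20, dtype=torch.long)])
sample_weights = (1.0 / torch.bincount(targets).double())[targets]

# Without replacement, num_samples can be at most the number of samples (100 here).
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=100,
    replacement=False,
    generator=torch.Generator().manual_seed(0),
)
order = torch.tensor(list(sampler))  # a weighted permutation of 0..99

# Number of class-1 samples in each consecutive batch of 10.
per_batch_minority = [int(targets[order[i:i + 10]].sum()) for i in range(0, 100, 10)]
print(per_batch_minority)
```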

The only sense I can make of this feature is to combine data augmentation with the WeightedRandomSampler, augmenting only the class “1” samples via random transforms. But that still makes it possible to see the exact same class “1” image multiple times during training, leading to at least slightly biased learning.

Am I missing something, or do I misunderstand the functionality of the WeightedRandomSampler?


That’s one use case, but I wouldn’t say it’s the only correct one. As the name suggests, this sampler is used for weighting the sampling strategy, not only for balancing it.
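As an illustration of weighting without fully balancing, here is a sketch that weights each sample by count**(-power). The power knob and the class_probabilities helper are hypothetical, not part of PyTorch: power=0 reproduces the original 80/20 distribution, power=1 gives the inverse-frequency 50/50 balancing, and values in between give a partial rebalancing (for this toy dataset, power=0.5 yields a 2:1 ratio, i.e. class “1” in about one third of the draws):

```python
import torch
from torch.utils.data import WeightedRandomSampler

def class_probabilities(targets: torch.Tensor, power: float) -> torch.Tensor:
    """Expected share of draws per class when each sample is weighted by count**(-power).

    Hypothetical helper for illustration; not a PyTorch API.
    """
    counts = torch.bincount(targets).double()
    per_class_mass = counts * counts.pow(-power)  # (#samples in class) * (weight per sample)
    return per_class_mass / per_class_mass.sum()

targets = torch.cat([torch.zeros(80, dtype=torch.long), torch.ones(20, dtype=torch.long)])

for power in (0.0, 0.5, 1.0):
    print(power, [round(p, 3) for p in class_probabilities(targets, power).tolist()])

# Partial rebalancing: each sample weighted by (its class count) ** -0.5.
weights = torch.bincount(targets).double().pow(-0.5)[targets]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
```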

Yes, your description is correct. If you want to have balanced batches, you should not use replacement=False.

Yes, that might be the case, but it could still perform better than training on the imbalanced dataset.
You would have to apply the sampler to your use case and compare the model performance metrics, e.g. to see how the confusion matrix changes and what an acceptable result would be.
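One way to do that comparison is to compute a confusion matrix on the same validation set for a model trained with the sampler and one trained without it. The sketch below uses only torch; the helper names are hypothetical, and sklearn.metrics.confusion_matrix would serve equally well. The tensors at the bottom are toy values that only show the layout (rows are true classes, columns are predictions):

```python
import torch

def confusion_matrix(preds: torch.Tensor, targets: torch.Tensor, num_classes: int = 2) -> torch.Tensor:
    """Rows are true classes, columns are predicted classes."""
    flat = targets * num_classes + preds  # unique bin per (true, predicted) pair
    return torch.bincount(flat, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def per_class_recall(cm: torch.Tensor) -> torch.Tensor:
    """Fraction of each true class that was predicted correctly."""
    return cm.diag().double() / cm.sum(dim=1).double()

# Toy values only, to show the layout; use your real validation predictions here.
targets = torch.tensor([0, 0, 0, 1, 1])
preds = torch.tensor([0, 1, 0, 1, 0])
cm = confusion_matrix(preds, targets)
print(cm)
print(per_class_recall(cm))
```

Running this for both models shows whether minority-class recall improves and at what cost to the majority class.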


I have the same dilemma. In my case, I would like to combine the WeightedRandomSampler with random data augmentation for the minority class. Is there an example of how to achieve this? @ptrblck

In the __getitem__ of your Dataset, you could check the current target value and apply (additional) transformations to the data for the minority classes.
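A minimal sketch of that suggestion, assuming the data is already a tensor of images. The class name, the minority_classes argument and random_flip_and_noise are hypothetical; in practice you would likely use torchvision transforms (e.g. RandomHorizontalFlip) instead. Because the random transform runs inside __getitem__, each time the sampler draws a minority index again it gets a freshly randomized version of that image, so repeats are not exact copies:

```python
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

def random_flip_and_noise(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical augmentation: random horizontal flip plus small Gaussian noise."""
    if torch.rand(1).item() < 0.5:
        x = torch.flip(x, dims=[-1])
    return x + 0.05 * torch.randn_like(x)

class MinorityAugmentDataset(Dataset):
    """Applies an extra random transform only to samples of the minority classes."""

    def __init__(self, data, targets, minority_classes=(1,), extra_transform=None):
        self.data = data
        self.targets = targets
        self.minority_classes = set(minority_classes)
        self.extra_transform = extra_transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        x = self.data[index]
        y = int(self.targets[index])
        # Check the current target and augment only minority-class samples.
        if y in self.minority_classes and self.extra_transform is not None:
            x = self.extra_transform(x)
        return x, y

# Toy data: 100 random 1x8x8 "images", 80 of class 0 and 20 of class 1.
data = torch.randn(100, 1, 8, 8)
targets = torch.cat([torch.zeros(80, dtype=torch.long), torch.ones(20, dtype=torch.long)])
dataset = MinorityAugmentDataset(data, targets, extra_transform=random_flip_and_noise)

# Balanced sampling with replacement, as discussed earlier in the thread.
sample_weights = (1.0 / torch.bincount(targets).double())[targets]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for xb, yb in loader:
    pass  # training step goes here
```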
