WeightedRandomSampler in PyTorch

I am implementing a multi-class image classification model in PyTorch. My data is not balanced, so I used the WeightedRandomSampler in PyTorch to create a custom DataLoader. However, I have one doubt: if my total number of classes is, for example, four times larger than the batch size I am using to train the model (which differs from the usual scenario), and I use replacement=True in the WeightedRandomSampler, wouldn’t it happen that my loader never yields images from some of the classes, e.g. instances of 3/4 of the classes? After all, replacement=True oversamples instances by repeating them.
Also, to use WeightedRandomSampler effectively with replacement=False, we have to control its total number of samples so that we don’t iterate over the entire dataset. Is there anything like this that needs to be taken care of when using replacement=True?

Also, to be clear about my dataset: it has around 2000 classes in a long-tail distribution, and the batch size is 64.
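
To make the setup concrete, here is a minimal sketch of the kind of sampler construction I mean (the class sizes are made up, and the inverse-class-frequency weights are my assumption, not necessarily the only choice):

```python
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader, TensorDataset

# Toy long-tail dataset: hypothetical stand-in for the real ~2000-class data.
labels = torch.cat([torch.zeros(1500, dtype=torch.long),    # head class with 1500 samples
                    torch.ones(50, dtype=torch.long),       # mid class with 50 samples
                    torch.full((5,), 2, dtype=torch.long)]) # tail class with 5 samples
data = torch.randn(len(labels), 3, 32, 32)
dataset = TensorDataset(data, labels)

# Per-sample weight = 1 / frequency of its class, so rare classes are drawn more often.
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].float()

sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(dataset),  # one "epoch" worth of draws
                                replacement=True)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)
```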

@ptrblck, could you have a look at this?

Thanks

If you are using replacement=True and set the length in the sampler to the length of the dataset, some samples of the majority class won’t be drawn.
In practice this doesn’t really matter, since it’s the majority class so missing samples shouldn’t yield any significant issues.
You could of course increase the sampler length, which will draw more samples per epoch, but you would get the same behavior by simply training your model for more epochs.
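
As a small illustrative sketch (not from the thread): num_samples directly controls how many indices the sampler yields per epoch, and with replacement=True it can exceed the dataset size.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy example: 10 samples with uniform weights (hypothetical numbers).
weights = torch.ones(10)

sampler_1x = WeightedRandomSampler(weights, num_samples=10, replacement=True)  # one dataset length per epoch
sampler_2x = WeightedRandomSampler(weights, num_samples=20, replacement=True)  # twice as many draws per epoch

print(len(list(sampler_1x)), len(list(sampler_2x)))  # 10 20
```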


Hi @ptrblck, Thank you very much for your answer.
My concern is that my total number of classes (~2000) is much larger than the batch size (~16), which I took to be the length of the sampler. With replacement=True (oversampling) and iterating up to the length of the dataset, wouldn’t there be many classes whose instances never actually come up in the whole epoch? I mean, since the length of the sampler (batch size) << number of classes, it might happen that only instances from 16 classes come up in every batch while iterating over one epoch.
Also, overall for one epoch, the instances from each category should be nearly balanced while iterating up to the length of the loader.
Sorry for the confusion, as my English is weak.
To give an idea about my dataset: the top 50 head classes have more than 1500 instances each, and the tail 100 classes have fewer than 5 instances each.
Is there some way I could verify that instances from all classes came up in a (roughly, not necessarily exactly) balanced way in every epoch of training, while iterating for total steps = total instances / batch size?

The length of the sampler is usually defined as the length of the dataset, not the batch size.
If you are dealing with 2000 classes, a batch size of 16, and e.g. 1 million samples, you could still use the WeightedRandomSampler approach to create batches “as if the dataset would be balanced”.
Of course a single batch won’t contain all 2000 classes, which is impossible due to batch_size < nb_classes.

You could count the drawn classes in each batch, e.g. via torch.unique(target, return_counts=True).
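
A minimal sketch of aggregating that per-batch count over a whole epoch (the toy class sizes and the inverse-frequency weights are assumptions for illustration, not part of the original answer):

```python
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader, TensorDataset

# Toy long-tail setup (hypothetical sizes: one head, one mid, one tail class).
labels = torch.cat([torch.zeros(1500, dtype=torch.long),
                    torch.ones(50, dtype=torch.long),
                    torch.full((5,), 2, dtype=torch.long)])
dataset = TensorDataset(torch.randn(len(labels), 1), labels)

class_counts = torch.bincount(labels)
weights = 1.0 / class_counts[labels].float()
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

# Count how often each class was drawn over one full epoch.
epoch_counts = torch.zeros_like(class_counts)
for _, target in loader:
    classes, counts = torch.unique(target, return_counts=True)
    epoch_counts[classes] += counts
print(epoch_counts)  # should be roughly uniform across the three classes
```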


Hi @ptrblck, thank you very much for your answers, they were of much use.

Thanks