Data Sampler to Handle Class Imbalance

I'm trying to train an object detector on my custom data. My data has a class imbalance issue: only 10% of the images contain objects, while the other 90% don't contain any. Here is what I want to do to solve the imbalance:

At every epoch, augment the 10% of images that contain objects by some factor so that their count matches the number of background images (90%), then train the model on this data.

For example, suppose I have 100 images in total, of which 10 contain objects and the other 90 don't contain any. At each epoch, randomly augment these 10 images to produce 90. Then train the model on these 90 (augmented, containing objects) + 90 (background) = 180 images. Repeat this for subsequent epochs.

Is there a way in PyTorch to build a data loader and/or data sampler that handles the scenario described above?

If you want a predefined number of samples from each class, you could implement a custom sampler, which would yield the batch indices using your desired logic.
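A minimal sketch of such a custom sampler, assuming you have a per-image label list where 1 means the image contains objects and 0 means background only. The class name `BalancedOversampler` and its `labels`/`seed` arguments are hypothetical, not part of PyTorch. Each epoch it yields every background index once, plus the foreground indices repeated until both groups have the same size. The random augmentation itself is assumed to live in your Dataset's `__getitem__`, so each repeated foreground index gets a different random transform.

```python
import random
from typing import Iterator, List

from torch.utils.data import Sampler


class BalancedOversampler(Sampler[int]):
    """Hypothetical sampler: per epoch, yields all background indices once plus
    foreground indices repeated so both groups have the same count."""

    def __init__(self, labels: List[int], seed: int = 0) -> None:
        # Assumed convention: labels[i] == 1 if image i contains objects, 0 otherwise
        self.fg = [i for i, y in enumerate(labels) if y == 1]
        self.bg = [i for i, y in enumerate(labels) if y == 0]
        if not self.fg:
            raise ValueError("need at least one image that contains objects")
        # One RNG kept across epochs, so every epoch gets a new shuffle
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        # Equal number of foreground and background samples per epoch
        return 2 * len(self.bg)

    def __iter__(self) -> Iterator[int]:
        target = len(self.bg)
        # Repeat the whole foreground list as often as it fits, then fill the
        # remainder with a random subset so each image appears about equally often
        full, rest = divmod(target, len(self.fg))
        fg = self.fg * full + self.rng.sample(self.fg, rest)
        indices = fg + self.bg
        self.rng.shuffle(indices)
        return iter(indices)
```

With the 10/90 example, each foreground index is yielded 9 times per epoch, giving 180 indices. Pass it as `DataLoader(dataset, batch_size=..., sampler=BalancedOversampler(labels))` and leave `shuffle` unset, since `DataLoader` treats `sampler` and `shuffle=True` as mutually exclusive.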
On the other hand, if your use case would allow weighted random sampling with replacement, you could use a WeightedRandomSampler as described here.
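As a rough sketch of the weighted alternative, again assuming a hypothetical binary `labels` tensor (1 = contains objects, 0 = background), each sample can be weighted by the inverse frequency of its class. Both classes are then drawn with roughly equal probability. The `num_samples` value and the seeded generator below are illustrative choices, not requirements.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical labels: 1 = image contains objects, 0 = background only
labels = torch.tensor([1] * 10 + [0] * 90)

# Inverse class frequency per sample, so each class has the same total weight
class_counts = torch.bincount(labels)                 # counts for class 0 and class 1
sample_weights = 1.0 / class_counts[labels].double()

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=2 * int(class_counts[0]),             # e.g. 180 draws per epoch
    replacement=True,
    generator=torch.Generator().manual_seed(0),       # seeded only for reproducibility
)
# loader = DataLoader(dataset, batch_size=16, sampler=sampler)
```

Unlike the custom sampler, this only balances the classes in expectation. Per-epoch class counts vary, and some background images may be drawn twice or skipped in a given epoch because sampling is with replacement.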