Weighted Random Sampler in PyTorch

I want to use a weighted random sampler for the training data in an image segmentation task in PyTorch. However, I am not sure how to compute the weights, since each image contains multiple labels. Should the weights be based on pixel counts or on the number of samples per class?
Thank you in advance for your help.

You can weight based on how often each class occurs across the whole training dataset. For example, if one class appears three times in one image and once in another, while a second class appears twice in the first image, the class counts are class_rep = [4, 2], and you assign the weights based on that distribution (typically so that rarer classes get larger weights). This is a simple but reasonable way to weight samples.
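As a rough sketch of the counting described above, the snippet below rebuilds class_rep = [4, 2] from toy annotations and turns it into one weight per image. The variable names and the toy data are made up for illustration, and the rule "weight each image by its rarest class" is only one possible assumption for images that contain several classes; it is not the only way to do it.

```python
from collections import Counter

# Hypothetical toy annotations: one list of instance class ids per image.
instances_per_image = [
    [0, 0, 0, 1, 1],  # image 0: class 0 three times, class 1 twice
    [0],              # image 1: class 0 once
]
num_classes = 2

# Count every instance of every class across the whole dataset.
counts = Counter(c for inst in instances_per_image for c in inst)
class_rep = [counts[c] for c in range(num_classes)]

# Inverse-frequency class weights: rarer classes get larger weights.
class_weight = [1.0 / n if n > 0 else 0.0 for n in class_rep]

# Assumption: an image is weighted by the rarest class it contains.
# Images without any instance get weight 0.0.
sample_weights = [
    max((class_weight[c] for c in set(inst)), default=0.0)
    for inst in instances_per_image
]

print(class_rep)
print(sample_weights)
```

The resulting sample_weights list has one entry per image, which is the shape expected by torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True). Taking the mean over the classes in an image instead of the maximum would be an equally valid choice to try.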

Another method is to count a class once per image if it appears there at least once, rather than counting every repeated instance. Which one is better depends on your dataset.
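A minimal sketch of this presence-based counting, using hypothetical toy annotations (a list of instance class ids per image), could look like this:

```python
# Hypothetical toy annotations: one list of instance class ids per image.
instances_per_image = [
    [0, 0, 0, 1, 1],  # image 0: class 0 three times, class 1 twice
    [0],              # image 1: class 0 once
]
num_classes = 2

# Count each class at most once per image, no matter how often it repeats.
class_presence = [0] * num_classes
for inst in instances_per_image:
    for c in set(inst):
        class_presence[c] += 1

print(class_presence)
```

Here class 0 is counted twice (it is present in both images) and class 1 once, so repeated instances inside one image no longer inflate the count.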

Thank you. Would it be more reasonable to count the number of pixels per class?

Counting label frequency by mask area (that is, counting pixels) might give a more accurate representation of frequency. However, it might introduce a bias toward object size.

For example, an elephant in a single image might end up with the same frequency as an apple that appears in 10 images.
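To make the pixel-based variant concrete, here is a small sketch that counts pixels per class over toy masks. The masks are tiny nested lists invented for illustration; with real NumPy arrays or tensors, np.bincount or torch.bincount on the flattened mask would be the more usual tool. Class 0 is treated like any other class here, which is an assumption; if it is background in your data you may want to exclude it.

```python
from collections import Counter

# Hypothetical toy masks: each mask is a 2D grid of class ids per pixel.
masks = [
    [[0, 0, 1],
     [0, 1, 1]],  # image 0: 3 pixels of class 0, 3 pixels of class 1
    [[0, 0, 0],
     [0, 0, 2]],  # image 1: 5 pixels of class 0, 1 pixel of class 2
]
num_classes = 3

# Count pixels per class over all masks.
pixel_counts = Counter(label for mask in masks for row in mask for label in row)
class_pixels = [pixel_counts[c] for c in range(num_classes)]

# Relative pixel frequency per class.
total = sum(class_pixels)
class_freq = [n / total for n in class_pixels]

print(class_pixels)
print(class_freq)
```

Class 1 and class 2 each appear in exactly one image, yet class 1 gets three times the pixel count of class 2 simply because its object is larger. That is the size bias mentioned above.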

It depends on your training procedure. You might try both of them and see what fits your task.