Load the same number of data per class

Thanks, Mika, it works like a charm.
I had to change the while condition so that the last mini-batch is also delivered to the data loader.

# in __iter__'s while condition, change "<" to "<=".
while self.count + self.batch_size <= len(self.dataset):
    # the rest of the code
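For context, here is a minimal, self-contained sketch of such a balanced sampler (the class and parameter names are hypothetical, not Mika's original code). It shows why the boundary condition matters: with "<", the final batch is silently dropped whenever the dataset length is an exact multiple of the batch size, while "<=" yields it.

```python
import random
from collections import defaultdict

class BalancedBatchSampler:
    """Yield batches containing an equal number of samples per class.

    Hypothetical sketch, assuming labels is a list of per-sample class
    labels; names are illustrative only.
    """
    def __init__(self, labels, classes_per_batch, samples_per_class):
        self.labels = labels
        self.classes_per_batch = classes_per_batch
        self.samples_per_class = samples_per_class
        self.batch_size = classes_per_batch * samples_per_class
        # Group dataset indices by class label.
        self.index_by_class = defaultdict(list)
        for idx, lab in enumerate(labels):
            self.index_by_class[lab].append(idx)

    def __iter__(self):
        self.count = 0
        # "<=" (not "<") ensures the last full mini-batch is yielded
        # when len(labels) is an exact multiple of batch_size.
        while self.count + self.batch_size <= len(self.labels):
            chosen = random.sample(list(self.index_by_class),
                                   self.classes_per_batch)
            batch = []
            for c in chosen:
                batch.extend(random.sample(self.index_by_class[c],
                                           self.samples_per_class))
            yield batch
            self.count += self.batch_size

    def __len__(self):
        return len(self.labels) // self.batch_size
```

A sampler like this can be passed to a DataLoader via its batch_sampler argument; each yielded batch is a list of dataset indices with exactly samples_per_class indices from each of classes_per_batch classes.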

In case anyone prefers using a library for this task, there is a similar sampler in PyTorch Metric Learning named MPerClassSampler. Refer to here
