Thanks, Mika, it works like a charm.
I had to change the while condition so that the last mini-batch is also delivered to the data loader. With "<", the final full batch is skipped when the dataset length is an exact multiple of the batch size.
# in __iter__'s while condition, change "<" to "<=".
while self.count + self.batch_size <= len(self.dataset):
# the rest of the code
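To see why the comparison matters, here is a minimal stand-alone sketch of just the loop condition. It assumes `count` starts at 0 and advances by `batch_size` on each iteration, as the condition above suggests; `count_batches` is a hypothetical helper written for illustration and is not part of the original sampler.

```python
import operator


def count_batches(dataset_len, batch_size, compare):
    """Count how many full batches the while loop would yield.

    `compare` stands in for the comparison in the while condition
    (operator.lt for "<", operator.le for "<=").
    Assumption: count advances by batch_size per iteration.
    """
    count = 0
    batches = 0
    while compare(count + batch_size, dataset_len):
        batches += 1
        count += batch_size
    return batches


# Dataset length is an exact multiple of the batch size.
strict = count_batches(10, 5, operator.lt)     # "<" stops one batch early
inclusive = count_batches(10, 5, operator.le)  # "<=" yields every full batch
print(strict, inclusive)
```

When the dataset length is not an exact multiple of the batch size (say 12 samples with a batch size of 5), both comparisons yield the same number of full batches, so the change only affects the exact-multiple case.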
In case anyone prefers using a library for this task, PyTorch Metric Learning has a similar sampler named MPerClassSampler. Refer to its documentation for details.