WeightedSampler by number of labels

Hi, how would you suggest creating a version of Sampler that lets me control the distribution of samples in a batch by their number of labels? I'm working with multi-label data.

For example, each batch would contain 3 samples with 4 labels, 7 samples with 2 labels, and 10 samples with one label.
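One way to get a fixed per-batch composition is a custom batch sampler: group the dataset indices by label count once, then build every batch by drawing the requested number of samples from each group. Below is a minimal sketch, assuming you can compute a list `label_counts` where `label_counts[i]` is the number of labels of sample `i`. The class name and its parameters are hypothetical, not part of PyTorch. Samples are distinct within a batch but can repeat across batches, so the "epoch" length is set by `num_batches`.

```python
import random
from collections import defaultdict


class LabelCountCompositionSampler:
    """Hypothetical batch sampler with a fixed number of samples per label count.

    `label_counts[i]` is the number of labels of sample i; `composition` maps
    label count -> samples per batch, e.g. {4: 3, 2: 7, 1: 10} for a batch of 20.
    """

    def __init__(self, label_counts, composition, num_batches, seed=None):
        self.composition = composition
        self.num_batches = num_batches
        self.rng = random.Random(seed)
        # Group dataset indices by how many labels each sample has.
        self.groups = defaultdict(list)
        for idx, n in enumerate(label_counts):
            self.groups[n].append(idx)
        for n, k in composition.items():
            if len(self.groups[n]) < k:
                raise ValueError(f"need {k} samples with {n} labels, have {len(self.groups[n])}")

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for n, k in self.composition.items():
                # Draw k distinct samples that have exactly n labels.
                batch.extend(self.rng.sample(self.groups[n], k))
            # Mix the groups so the batch is not ordered by label count.
            self.rng.shuffle(batch)
            yield batch

    def __len__(self):
        return self.num_batches


# Toy usage: 50 one-label, 30 two-label and 10 four-label samples.
label_counts = [1] * 50 + [2] * 30 + [4] * 10
sampler = LabelCountCompositionSampler(label_counts, {4: 3, 2: 7, 1: 10}, num_batches=5, seed=0)
for batch in sampler:
    print(sorted(label_counts[i] for i in batch))
```

Since it yields lists of indices, it should plug into a regular loader as `DataLoader(dataset, batch_sampler=sampler)`.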

Moreover, in my case I would prefer each batch (say, with batch_size=20) to consist entirely of samples with the same number of labels, while the sequence of batches loops over all the different label counts.
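The "same label count per batch, cycling over label counts" variant can also be written as a batch sampler. A minimal sketch, assuming the same hypothetical `label_counts` list as input: each epoch it shuffles the indices inside each label-count group, cuts each group into batches of `batch_size`, and then emits batches round-robin (1 label, 2 labels, ..., 1 label, 2 labels, ...) until every group is used up. The class name and the `drop_last` handling are my own choices, not a PyTorch API.

```python
import random
from collections import defaultdict


class SameLabelCountBatchSampler:
    """Hypothetical batch sampler: every batch holds samples with one label count.

    Batches are emitted round-robin over label counts until all groups are
    exhausted, so each sample is used at most once per epoch.
    """

    def __init__(self, label_counts, batch_size, drop_last=True, seed=None):
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.rng = random.Random(seed)
        self.groups = defaultdict(list)
        for idx, n in enumerate(label_counts):
            self.groups[n].append(idx)

    def _batches_per_group(self):
        per_group = {}
        for n, idxs in self.groups.items():
            idxs = idxs[:]  # copy so the stored grouping is left untouched
            self.rng.shuffle(idxs)
            chunks = [idxs[i:i + self.batch_size] for i in range(0, len(idxs), self.batch_size)]
            # Optionally drop the incomplete last batch of each group.
            if self.drop_last and chunks and len(chunks[-1]) < self.batch_size:
                chunks.pop()
            per_group[n] = chunks
        return per_group

    def __iter__(self):
        per_group = self._batches_per_group()
        order = sorted(per_group)
        # Round-robin: one batch from each label count in turn.
        while any(per_group[n] for n in order):
            for n in order:
                if per_group[n]:
                    yield per_group[n].pop()

    def __len__(self):
        total = 0
        for idxs in self.groups.values():
            full, rest = divmod(len(idxs), self.batch_size)
            total += full + (0 if self.drop_last or rest == 0 else 1)
        return total
```

Round-robin is one possible ordering; if strictly alternating label counts is too regular for training, the per-group batches could instead be pooled and shuffled together.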

The easiest way I can think of: if I have 1-5 labels, create 5 dataloaders, one per label count. Then, for each batch, I toss a number with random.random() and, according to a prior distribution over label counts (say, a dataset of 10000 samples with 2000 samples for each label count from 1 to 5), first choose which dataloader to load from and then load a batch with the label count I tossed.
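The five-dataloader idea can probably be collapsed into a single batch sampler that does the "toss" itself. A minimal sketch under these assumptions: `prior` is a hypothetical dict mapping label count to probability (defaulting to the dataset's own label-count frequencies), `random.Random.choices` stands in for tossing `random.random()` against the cumulative prior, and each label-count group keeps its own shuffled pool that is refilled when it runs low, roughly like restarting that group's dataloader. The leftover samples of a pool (fewer than `batch_size`) are discarded on refill, similar to `drop_last=True`.

```python
import random
from collections import defaultdict


class PriorLabelCountBatchSampler:
    """Hypothetical sketch of 'one loader per label count' as one batch sampler.

    For each batch a label count is drawn from `prior`, then `batch_size`
    indices are taken from that label count's shuffled pool.
    """

    def __init__(self, label_counts, batch_size, num_batches, prior=None, seed=None):
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.rng = random.Random(seed)
        self.groups = defaultdict(list)
        for idx, n in enumerate(label_counts):
            self.groups[n].append(idx)
        if prior is None:
            # Default prior: the dataset's own label-count frequencies.
            total = len(label_counts)
            prior = {n: len(idxs) / total for n, idxs in self.groups.items()}
        self.keys = sorted(prior)
        self.weights = [prior[n] for n in self.keys]
        for n in self.keys:
            if len(self.groups[n]) < batch_size:
                raise ValueError(f"label count {n} has fewer than {batch_size} samples")

    def __iter__(self):
        pools = {n: [] for n in self.keys}
        for _ in range(self.num_batches):
            # Pick which "loader" (label count) to draw from, according to the prior.
            n = self.rng.choices(self.keys, weights=self.weights, k=1)[0]
            if len(pools[n]) < self.batch_size:
                # Pool ran low: reshuffle the whole group, like a new epoch of that loader.
                pools[n] = self.groups[n][:]
                self.rng.shuffle(pools[n])
            batch, pools[n] = pools[n][:self.batch_size], pools[n][self.batch_size:]
            yield batch

    def __len__(self):
        return self.num_batches


# The example from the question: 10000 samples, 2000 for each label count 1-5.
label_counts = [n for n in range(1, 6) for _ in range(2000)]
sampler = PriorLabelCountBatchSampler(label_counts, batch_size=20, num_batches=500, seed=0)
```

A single `DataLoader(dataset, batch_sampler=sampler)` then replaces the five loaders, and passing e.g. `prior={1: 0.5, 2: 0.2, 3: 0.1, 4: 0.1, 5: 0.1}` would skew batches toward one-label samples.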

Any other creative suggestions? Is there a built-in, torch-y way to do it?