DataLoader: make sure each batch contains at least one sample from every class

Hi folks,

I ran into a question while comparing model performance between class-balanced and instance-balanced sampling. My dataset is somewhat skewed: one class has far fewer samples than the others. How can I set the sampling weights so that every batch contains at least one sample of this few-shot class, while the larger classes still make up most of each batch?

Currently, some batches contain no samples of this few-shot class, which makes the loss evaluate to NaN.
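To make the class-balanced vs. instance-balanced comparison concrete, here is a minimal sketch of how the two schemes translate into per-sample weights, the form `torch.utils.data.WeightedRandomSampler` takes. It is plain Python. The helper names `per_sample_weights` and `expected_class_share` and the toy labels (90 vs. 10 samples) are hypothetical and only for illustration.

```python
from collections import Counter
from typing import Dict, List, Sequence


def per_sample_weights(labels: Sequence[int], scheme: str) -> List[float]:
    """Per-sample weights, one per dataset index (hypothetical helper)."""
    counts = Counter(labels)
    if scheme == "instance":
        # Every sample equally likely: classes appear in their natural proportion.
        return [1.0] * len(labels)
    if scheme == "class":
        # Every class gets the same total mass, split evenly over its samples.
        return [1.0 / counts[y] for y in labels]
    raise ValueError(f"unknown scheme: {scheme!r}")


def expected_class_share(labels: Sequence[int], weights: Sequence[float]) -> Dict[int, float]:
    """Probability that a single weighted draw lands in each class."""
    total = sum(weights)
    share: Dict[int, float] = {}
    for y, w in zip(labels, weights):
        share[y] = share.get(y, 0.0) + w / total
    return share


labels = [0] * 90 + [1] * 10  # hypothetical toy data: class 1 is the few-shot class
for scheme in ("instance", "class"):
    print(scheme, expected_class_share(labels, per_sample_weights(labels, scheme)))
```

Neither scheme guarantees anything about an individual batch. Each weighted draw is independent, so a batch can still miss the few-shot class, which is the NaN case described above.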

Thank you in advance.

Would writing a custom sampler that is guaranteed to draw from the smaller class every x samples work? Alternatively, a cruder but simpler solution is to set the weights for the balance you want and simply throw out batches that contain no samples of that class. Note that if you set shuffle=True, this will still use all of your data over time, just not in every epoch.
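Here is a minimal sketch of the cruder option. It assumes the labels of the whole dataset are available as a list of ints. `RejectionWeightedBatchSampler` and its parameters are hypothetical names, not part of PyTorch. Each batch is drawn with replacement using per-class weights (the same idea as `WeightedRandomSampler`), and any batch without the rare class is re-drawn before any data is loaded. Because it defines `__iter__` and `__len__`, it can be passed as `DataLoader(dataset, batch_sampler=...)`.

```python
import random
from itertools import accumulate
from typing import Dict, Iterator, List, Optional, Sequence


class RejectionWeightedBatchSampler:
    """Hypothetical batch sampler: weighted sampling with replacement,
    re-drawing any batch that contains no sample of `rare_class`."""

    def __init__(self, labels: Sequence[int], class_weights: Dict[int, float],
                 batch_size: int, rare_class: int, num_batches: int,
                 max_tries: int = 1000, seed: Optional[int] = None) -> None:
        self.labels = list(labels)
        self.batch_size = batch_size
        self.rare_class = rare_class
        self.num_batches = num_batches
        self.max_tries = max_tries
        self.rng = random.Random(seed)
        # Each sample inherits its class weight; precompute cumulative weights once.
        self.cum_weights = list(accumulate(class_weights[y] for y in self.labels))

    def __len__(self) -> int:
        return self.num_batches

    def __iter__(self) -> Iterator[List[int]]:
        population = range(len(self.labels))
        for _ in range(self.num_batches):
            for _ in range(self.max_tries):
                batch = self.rng.choices(population, cum_weights=self.cum_weights,
                                         k=self.batch_size)
                # Keep the batch only if the rare class is represented.
                if any(self.labels[i] == self.rare_class for i in batch):
                    yield batch
                    break
            else:
                raise RuntimeError("rare class never drawn; increase its weight")
```

This also shows why the rare class does not need a weight close to the large classes. Every draw is independent, so if p is the share of sampling mass on the few-shot class, a batch of size B contains none of it with probability (1 - p)^B. With p = 0.05 and B = 64, that is about 3.8%, so few batches are discarded, while on average only about 3 of the 64 samples come from the few-shot class. Conditioning on "at least one" nudges the rare share slightly above p.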

Thanks for the note. I think the simpler approach would work, but I'm not sure it guarantees that the smaller class gets drawn while the larger classes keep the majority, unless I weight the smaller class very heavily, close to the larger classes. In that case, though, it becomes more like instance-balanced sampling than class-balanced sampling. Could you recommend any reference code for a custom sampler that does something similar?
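One way to write the custom sampler suggested earlier is sketched below. It assumes the labels are available up front. `RareClassBatchSampler`, `rare_per_batch` and the other names are hypothetical, not a PyTorch API. Each batch reserves `rare_per_batch` slots for the few-shot class. Those slots are filled by cycling through its samples and reshuffling on each pass, which oversamples it. The remaining slots come from a shuffled pass over all other samples, so those classes keep their natural proportions and still dominate every batch.

```python
import random
from typing import Iterator, List, Optional, Sequence


class RareClassBatchSampler:
    """Hypothetical batch sampler: every batch holds `rare_per_batch` indices
    of `rare_class`; the remaining slots are filled from all other classes."""

    def __init__(self, labels: Sequence[int], rare_class: int, batch_size: int,
                 rare_per_batch: int = 1, drop_last: bool = True,
                 seed: Optional[int] = None) -> None:
        if not 0 < rare_per_batch < batch_size:
            raise ValueError("need 0 < rare_per_batch < batch_size")
        self.rare_idx = [i for i, y in enumerate(labels) if y == rare_class]
        self.other_idx = [i for i, y in enumerate(labels) if y != rare_class]
        if not self.rare_idx:
            raise ValueError("rare_class does not occur in labels")
        self.rare_per_batch = rare_per_batch
        self.n_other = batch_size - rare_per_batch  # slots for the other classes
        self.drop_last = drop_last
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        full, rest = divmod(len(self.other_idx), self.n_other)
        return full if self.drop_last or rest == 0 else full + 1

    def _rare_stream(self) -> Iterator[int]:
        # Endless stream over the rare indices, reshuffled on every pass.
        while True:
            pool = self.rare_idx[:]
            self.rng.shuffle(pool)
            yield from pool

    def __iter__(self) -> Iterator[List[int]]:
        others = self.other_idx[:]
        self.rng.shuffle(others)  # each non-rare sample used at most once per epoch
        rare = self._rare_stream()
        for b in range(len(self)):
            batch = others[b * self.n_other:(b + 1) * self.n_other]
            batch += [next(rare) for _ in range(self.rare_per_batch)]
            self.rng.shuffle(batch)  # avoid a fixed position for the rare samples
            yield batch
```

Usage would look like `DataLoader(dataset, batch_sampler=RareClassBatchSampler(labels, rare_class=..., batch_size=64))`. In `DataLoader`, `batch_sampler` is mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`. With this design, an epoch is one pass over the non-rare samples, and the few-shot samples are repeated as often as needed.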