In your current approach it seems that you are using the class_counts to create a WeightedRandomSampler, while each sample should get its own weight, as described in this post.
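WeightedRandomSampler expects one weight per sample, not one per class. If you pass the class counts directly, it only draws indices in the range [0, num_classes). Here is a minimal sketch of expanding inverse class frequencies into per-sample weights. It assumes the labels are available as a 1-D tensor; targets and the toy labels are hypothetical placeholders for your dataset's labels.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical labels of an imbalanced 3-class dataset (replace with your own)
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])

# Number of samples per class: tensor([6, 3, 1])
class_counts = torch.bincount(targets)
# Inverse frequency as the weight of each class
class_weights = 1.0 / class_counts.double()
# Index with the labels so that every sample gets the weight of its class
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,            # one weight per sample
    num_samples=len(sample_weights),   # draw one "epoch" worth of indices
    replacement=True,                  # needed to oversample rare classes
)
```

The sampler can then be passed as DataLoader(dataset, batch_size=..., sampler=sampler). Note that this only balances the classes in expectation; it does not guarantee a minimum count per batch.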
I'm also unsure whether each batch should contain at least n samples from each class, or whether this requirement refers to the dataset splits.
In the former case, you could write a custom sampler (and remove the WeightedRandomSampler) such that the indices are sampled according to your condition.
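For the per-batch case, one option is a batch sampler that first draws n indices from each class and then fills the rest of the batch randomly. The sketch below is a hypothetical example, not an existing PyTorch class: MinPerClassBatchSampler, n_per_class and num_batches are made-up names, and it samples with replacement so that classes with fewer than n samples still work.

```python
import random
from collections import defaultdict


class MinPerClassBatchSampler:
    """Hypothetical batch sampler: every batch contains at least
    `n_per_class` indices of each class; the rest is filled randomly."""

    def __init__(self, targets, batch_size, n_per_class, num_batches, seed=None):
        # Group dataset indices by their label
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(targets):
            self.class_to_indices[int(label)].append(idx)
        if n_per_class * len(self.class_to_indices) > batch_size:
            raise ValueError("batch_size is smaller than n_per_class * num_classes")
        self.all_indices = list(range(len(targets)))
        self.batch_size = batch_size
        self.n_per_class = n_per_class
        self.num_batches = num_batches
        self.rng = random.Random(seed)

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            # Guarantee the minimum number of samples for every class
            # (choices() samples with replacement)
            for indices in self.class_to_indices.values():
                batch.extend(self.rng.choices(indices, k=self.n_per_class))
            # Fill the remaining slots uniformly from the whole dataset
            remaining = self.batch_size - len(batch)
            batch.extend(self.rng.choices(self.all_indices, k=remaining))
            # Avoid class-ordered batches
            self.rng.shuffle(batch)
            yield batch

    def __len__(self):
        return self.num_batches
```

Since it yields lists of indices, it can be passed as DataLoader(dataset, batch_sampler=sampler); in that case batch_size, shuffle, sampler and drop_last must not be set on the DataLoader.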