You could write a custom sampler and use the existing implementations as a base class.
The sampler is responsible for creating the indices, which are then passed to `Dataset.__getitem__`.
You could thus use the target tensor to create batches of indices with your custom sampling logic.
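Here is a minimal sketch of such a batch sampler. It assumes integer class labels, e.g. your dataset's `targets`. The class name `ClassBalancedBatchSampler` and its parameters are hypothetical. Each batch contains `n_classes` distinct classes with `n_samples` indices per class. Classes with fewer than `n_samples` entries are drawn with replacement, which is also an assumption on my part. The sketch does not subclass `torch.utils.data.Sampler`, so it runs without PyTorch. `DataLoader` also accepts any iterable of index lists as `batch_sampler`. You can still subclass `Sampler` if you prefer.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence


class ClassBalancedBatchSampler:
    """Hypothetical batch sampler: every batch holds `n_classes` distinct
    classes with `n_samples` indices drawn from each of them."""

    def __init__(self, targets: Sequence[int], n_classes: int, n_samples: int,
                 n_batches: int, seed: int = 0) -> None:
        # Group dataset indices by their class label (int() also handles 0-d tensors).
        self.class_to_indices: Dict[int, List[int]] = defaultdict(list)
        for idx, label in enumerate(targets):
            self.class_to_indices[int(label)].append(idx)
        if n_classes > len(self.class_to_indices):
            raise ValueError("n_classes exceeds the number of distinct classes")
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_batches = n_batches
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        labels = list(self.class_to_indices)
        for _ in range(self.n_batches):
            # Pick distinct classes for this batch.
            chosen = self.rng.sample(labels, self.n_classes)
            batch: List[int] = []
            for label in chosen:
                pool = self.class_to_indices[label]
                if len(pool) >= self.n_samples:
                    # Enough samples: draw without replacement.
                    batch.extend(self.rng.sample(pool, self.n_samples))
                else:
                    # Small class: fall back to drawing with replacement.
                    batch.extend(self.rng.choices(pool, k=self.n_samples))
            yield batch

    def __len__(self) -> int:
        return self.n_batches
```

You would pass it as `DataLoader(dataset, batch_sampler=ClassBalancedBatchSampler(dataset.targets, n_classes=4, n_samples=8, n_batches=100))`. Leave `batch_size`, `shuffle`, `sampler` and `drop_last` at their defaults, since they are mutually exclusive with `batch_sampler`.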
If you would like to use weighted sampling, you could use `WeightedRandomSampler` instead.
Note, however, that this sampler does not guarantee a specific number of classes in each batch.
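A common way to build the weights is inverse class frequency, where each sample's weight is 1 / (number of samples in its class). This is a sketch, and the helper names `inverse_frequency_weights` and `make_weighted_sampler` are hypothetical. PyTorch is imported lazily inside `make_weighted_sampler`, so the weight computation works on its own. The `WeightedRandomSampler(weights, num_samples, replacement)` call uses its documented arguments.

```python
from collections import Counter
from typing import List, Sequence


def inverse_frequency_weights(targets: Sequence[int]) -> List[float]:
    """Hypothetical helper: per-sample weight = 1 / (count of its class)."""
    labels = [int(t) for t in targets]
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


def make_weighted_sampler(targets: Sequence[int]):
    """Hypothetical helper wrapping torch's WeightedRandomSampler."""
    # Imported here so inverse_frequency_weights works without PyTorch installed.
    from torch.utils.data import WeightedRandomSampler

    weights = inverse_frequency_weights(targets)
    # replacement=True lets minority-class samples be drawn more than once per epoch.
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
```

With these weights each class is drawn with roughly equal probability in expectation. Individual batches are still random, which is why a fixed number of classes per batch is not guaranteed.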