How to enable the dataloader to sample from each class with equal probability

Yeah, this is called stratified sampling… I actually implemented this in my third-party package torchsample as a sampler; it’s aptly named `StratifiedSampler` ([see here](https://github.com/ncullen93/torchsample/blob/master/torchsample/samplers.py#L22)), and there’s an example of it in action below. You can likely just copy that class and use it with torchvision by passing it as the `sampler` argument to a `DataLoader`. Something like this:

```python
import numpy as np
import torch
from torchsample.samplers import StratifiedSampler

y = torch.from_numpy(np.array([0, 0, 1, 1, 0, 0, 1, 1]))
sampler = StratifiedSampler(class_vector=y, batch_size=2)
# then pass this sampler as an argument to DataLoader
```
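To make that concrete, here’s a minimal sketch of plugging the sampler into a `DataLoader`. The `TensorDataset` and the random `X` tensor are just stand-ins for whatever dataset you’re actually using; keep the `DataLoader`’s `batch_size` equal to the one you gave the sampler so the batches line up:

```python
from torch.utils.data import DataLoader, TensorDataset

# dummy features paired with the labels above (stand-in for a real dataset)
X = torch.randn(8, 4)
dataset = TensorDataset(X, y)

# shuffle must be left off when a custom sampler is supplied
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
for batch_x, batch_y in loader:
    print(batch_y)  # batches should come out roughly class-balanced
```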

Let me know if you need help adapting it. It depends on scikit-learn, unfortunately, because they have a ton of good samplers like that and I didn’t feel like reimplementing them.
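If you’d rather avoid the torchsample dependency, here’s a rough sketch of the same idea built directly on scikit-learn’s `StratifiedShuffleSplit`. This is my own simplified stand-in (class name and all), not the exact torchsample implementation:

```python
import torch
from torch.utils.data import Sampler
from sklearn.model_selection import StratifiedShuffleSplit

class SimpleStratifiedSampler(Sampler):
    """Simplified stand-in for torchsample's StratifiedSampler:
    yields indices so that each consecutive batch roughly keeps
    the overall class proportions."""
    def __init__(self, class_vector, batch_size):
        self.class_vector = class_vector
        self.batch_size = batch_size
        self.n_batches = len(class_vector) // batch_size

    def __iter__(self):
        # each "test" split is one stratified batch of indices
        splitter = StratifiedShuffleSplit(n_splits=self.n_batches,
                                          test_size=self.batch_size)
        dummy = torch.zeros(len(self.class_vector), 1).numpy()  # features are unused
        labels = self.class_vector.numpy()
        indices = []
        for _, batch_idx in splitter.split(dummy, labels):
            indices.extend(batch_idx.tolist())
        return iter(indices)

    def __len__(self):
        return self.n_batches * self.batch_size
```

One caveat of this simplified version: because each batch is drawn by an independent shuffle-split, a given sample can show up in more than one batch per epoch.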
