How to enable the dataloader to sample from each class with equal probability

The dataloader utility in torch (courtesy of Soumith Chintala) allowed one to sample from each class with equal probability. I was wondering if there is a straightforward way to enable the same in PyTorch dataloaders.


Yeah, this is called stratified sampling… I actually implemented this in my third-party package torchsample as a sampler; it’s aptly named `StratifiedSampler` (see here). Here’s an example of it in action as well. You can likely just copy this class and pass it as the `sampler` argument to a DataLoader. Something like this:

import numpy as np
import torch
# StratifiedSampler: the class copied from torchsample
y = torch.from_numpy(np.array([0, 0, 1, 1, 0, 0, 1, 1]))
sampler = StratifiedSampler(class_vector=y, batch_size=2)
# then pass this sampler as the `sampler` argument to DataLoader

Let me know if you need help adapting it. It depends on scikit-learn unfortunately, because they have a ton of good samplers like that and I didn’t feel like reimplementing it.


Nice work @ncullen93, thanks !

This is extremely useful. Thanks a lot. I was looking for code that selects each class with equal probability and then samples an instance from that class, again with equal probability. However, stratified sampling does the job well.
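The two-stage scheme described here (pick a class uniformly, then an instance uniformly within it) differs from stratified sampling, which keeps the dataset's class proportions rather than equalizing them. A minimal sketch of the two-stage scheme with PyTorch's built-in `WeightedRandomSampler`, assuming a 1-D tensor of integer labels: giving every sample the weight 1 / (size of its class) makes every class's total weight equal, so each class is drawn with equal probability and each instance uniformly within its class. The helper name `class_balanced_sampler` is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def class_balanced_sampler(labels, num_samples=None):
    """Hypothetical helper: draw a class uniformly at random, then a sample
    uniformly from within that class (sampling with replacement)."""
    labels = labels.cpu().long()
    class_counts = torch.bincount(labels)                  # samples per class
    sample_weights = 1.0 / class_counts[labels].double()  # 1 / |class| per sample
    if num_samples is None:
        num_samples = len(labels)                          # one pass worth of draws
    return WeightedRandomSampler(sample_weights, num_samples=num_samples,
                                 replacement=True)


torch.manual_seed(0)
y = torch.tensor([0] * 90 + [1] * 10)  # 9:1 class imbalance
X = torch.randn(len(y), 4)
loader = DataLoader(TensorDataset(X, y), batch_size=100,
                    sampler=class_balanced_sampler(y, num_samples=10_000))
drawn = torch.cat([yb for _, yb in loader])
print(torch.bincount(drawn, minlength=2).tolist())  # roughly equal class counts
```

Because sampling is with replacement, minority-class samples repeat within an epoch while some majority-class samples may not be seen at all; that is the trade-off of equalizing class frequencies this way.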

I will try using it in my code and let you know if I have any doubts. Thanks again.

I am trying to get balanced classes for a multi-class classification task. I have tried to use with no success. I am using your implementation, but I get this error: ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2. Here is my code:


train_set = SentimentDataset(file=TRAIN_DATA, word2idx=word2idx, tword2idx=tword2idx,
                             max_length=0, max_topic_length=0, topic_bs=True)
val_set = SentimentDataset(file=VAL_DATA, word2idx=word2idx, tword2idx=tword2idx,
                           max_length=0, max_topic_length=0, topic_bs=True)

_weights = 1 / torch.FloatTensor(train_set.weights) # [296, 3381, 12882, 12857, 1016]
_weights = _weights.view(1, 5)
_weights = _weights.double()

sampler = StratifiedSampler(_weights, BATCH_SIZE)

loader_train = DataLoader(train_set, batch_size=BATCH_SIZE,
                          shuffle=False, sampler=sampler, num_workers=4)

loader_val = DataLoader(val_set, batch_size=BATCH_SIZE,
                        shuffle=False, sampler=sampler, num_workers=4)

model = RNN(embeddings, num_classes=num_classes, **_hparams)

criterion = torch.nn.CrossEntropyLoss()
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameters)
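The error text comes from scikit-learn's stratified splitters, which require every distinct value of `y` to occur at least twice. The snippet appears to pass the five inverse class frequencies (`_weights`, reshaped to `1 x 5`) where the earlier example passes `class_vector`, one label per sample, so each value forms its own one-member "class". A minimal reproduction with `StratifiedShuffleSplit` directly, assuming the sampler hands its input to scikit-learn as `y`; the per-sample label array below is made up for illustration:

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)

# One value per class (like the inverse class frequencies): each "class" has 1 member.
class_weights = 1.0 / np.array([296, 3381, 12882, 12857, 1016])
try:
    next(splitter.split(np.zeros((len(class_weights), 1)), class_weights))
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)

# One label per sample (illustrative data): the split succeeds and is stratified.
labels = np.repeat(np.arange(5), [4, 6, 10, 10, 2])
train_idx, test_idx = next(splitter.split(np.zeros((len(labels), 1)), labels))
print(np.bincount(labels[test_idx], minlength=5))
```

Separately, `loader_val` reuses the sampler built for the training set. A sampler yields indices computed from the labels it was given, so the validation loader would need its own sampler built from the validation labels (or no sampler at all).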


In your StratifiedSampler, why do you set n_splits to the number of batches when you only take one split from the shuffle-split iterator? To my knowledge, n_splits plays the role of K in K-fold cross-validation (the number of splits generated), and StratifiedShuffleSplit just ensures that in each split the class distribution follows that of the whole dataset. Since you reconstruct the StratifiedShuffleSplit every time anyway, wouldn't it make more sense for your StratifiedSampler to use n_splits=1?
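The point about n_splits can be checked directly: `StratifiedShuffleSplit` generates `n_splits` independent random splits, and taking only the first one with `next(...)` means a single split is all that is consumed. A minimal sketch with `n_splits=1`, assuming the sampler concatenates the two halves of one split into an epoch's index order (the function name `stratified_permutation` is hypothetical, not torchsample's API):

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


def stratified_permutation(labels, seed=None):
    """Hypothetical helper: one stratified shuffle of all indices.

    With n_splits=1 the splitter yields exactly one (train, test) split.
    With test_size=0.5 and no train_size, the two halves together cover
    every index once, and each half keeps the class proportions of labels.
    """
    labels = np.asarray(labels)
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=seed)
    train_idx, test_idx = next(splitter.split(np.zeros((len(labels), 1)), labels))
    return np.concatenate([train_idx, test_idx])


labels = np.array([0, 0, 1, 1, 0, 0, 1, 1])
order = stratified_permutation(labels, seed=0)
print(order, labels[order])
```

Note that only the two halves are guaranteed to be stratified; consecutive batches cut from this order are not individually stratified.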


I wrote a new implementation that feels a bit cleaner and can be used with the batch_sampler argument of DataLoader. It supports reproducibility with torch.manual_seed(seed) for shuffle mode.


import torch
from sklearn.model_selection import StratifiedKFold

class StratifiedBatchSampler:
    """Stratified batch sampling
    Provides equal representation of target classes in each batch
    """
    def __init__(self, y, batch_size, shuffle=True):
        if torch.is_tensor(y):
            y = y.numpy()
        assert len(y.shape) == 1, 'label array must be 1D'
        n_batches = int(len(y) / batch_size)
        # each of the n_batches test folds becomes one stratified batch
        self.skf = StratifiedKFold(n_splits=n_batches, shuffle=shuffle)
        # dummy features: StratifiedKFold only needs len(X) and the labels y
        self.X = torch.randn(len(y),1).numpy()
        self.y = y
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # draw the fold seed from torch's RNG so torch.manual_seed(seed) controls it
            self.skf.random_state = torch.randint(0,int(1e8),size=()).item()
        for train_idx, test_idx in self.skf.split(self.X, self.y):
            # the indices of one test fold form one batch
            yield test_idx

    def __len__(self):
        return len(self.y)

Usage example:

import torch
from torch.utils.data import TensorDataset, DataLoader

X = torch.randn(100,20)
y = torch.randint(0,7,size=(100,))

data_loader = DataLoader(
    TensorDataset(X, y),
    batch_sampler=StratifiedBatchSampler(y, batch_size=5))

Looks like a good solution! Why set the seed here instead of at creation time?
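One way to see what drawing the seed inside `__iter__` changes, using `StratifiedKFold` directly: a seed fixed at creation time reproduces the same folds, and therefore the same batches, on every epoch, whereas a fresh seed drawn from torch's generator on each pass gives new batches each epoch while `torch.manual_seed` still makes the whole sequence reproducible across runs. The helpers `fold_batches` and `torch_seeded_epochs` are hypothetical, for illustration only.

```python
import numpy as np
import torch
from sklearn.model_selection import StratifiedKFold


def fold_batches(y, n_batches, random_state):
    """Hypothetical helper: the test folds of one StratifiedKFold pass, as lists."""
    skf = StratifiedKFold(n_splits=n_batches, shuffle=True, random_state=random_state)
    X = np.zeros((len(y), 1))  # features are ignored; only len(X) and y matter
    return [test_idx.tolist() for _, test_idx in skf.split(X, y)]


y = np.repeat([0, 1, 2], 10)  # 30 samples, 3 balanced classes

# Seed fixed at creation time: every "epoch" yields identical batches.
fixed_epoch_1 = fold_batches(y, n_batches=5, random_state=123)
fixed_epoch_2 = fold_batches(y, n_batches=5, random_state=123)


def torch_seeded_epochs(n_epochs):
    """Hypothetical helper: one fresh seed per epoch, drawn from torch's RNG."""
    torch.manual_seed(0)
    return [fold_batches(y, 5, torch.randint(0, int(1e8), size=()).item())
            for _ in range(n_epochs)]


run_a = torch_seeded_epochs(2)
run_b = torch_seeded_epochs(2)
print(fixed_epoch_1 == fixed_epoch_2)  # same batches every epoch
print(run_a == run_b)                  # reproducible across runs
print(run_a[0] == run_a[1])            # epochs differ within a run
```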

I think there is one small correction to make, though:

shouldn't __len__ return len(self.y) // batch_size?


The line y = y.numpy() raised an error while training on the GPU: y must be moved to the CPU before it can be converted to a NumPy array.

You can replace y = y.numpy() with y = y.cpu().numpy()


According to the documentation page for the DataLoader, "len(dataloader) heuristic is based on the length of the sampler used." So yes, if you want len(dataloader) to give the number of batches rather than the total number of samples, I agree that returning len(self.y) // batch_size (or storing n_batches as self.n_batches in __init__ and returning that) is a good idea.
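Folding the corrections from this thread back into the `StratifiedBatchSampler` above gives something like the following sketch: labels are moved to the CPU before conversion, the number of batches is stored as `self.n_batches`, and `__len__` returns it so that `len(dataloader)` counts batches. This is a consolidation for illustration, not code posted by the original author. As before, `StratifiedKFold` keeps each batch's class proportions close to the dataset's, and batch sizes can vary slightly when `len(y)` is not a multiple of `batch_size`.

```python
import torch
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, TensorDataset


class StratifiedBatchSampler:
    """Stratified batch sampling: each batch roughly preserves the class
    proportions of y (sketch combining the fixes suggested in this thread)."""

    def __init__(self, y, batch_size, shuffle=True):
        if torch.is_tensor(y):
            y = y.cpu().numpy()  # also works for CUDA tensors
        assert len(y.shape) == 1, 'label array must be 1D'
        self.n_batches = len(y) // batch_size
        self.skf = StratifiedKFold(n_splits=self.n_batches, shuffle=shuffle)
        self.X = torch.zeros(len(y), 1).numpy()  # dummy features; only y matters
        self.y = y
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # fresh seed per epoch, reproducible under torch.manual_seed(seed)
            self.skf.random_state = torch.randint(0, int(1e8), size=()).item()
        for _, test_idx in self.skf.split(self.X, self.y):
            yield test_idx.tolist()  # one stratified batch of indices

    def __len__(self):
        return self.n_batches  # number of batches, so len(DataLoader) matches


torch.manual_seed(0)
X = torch.randn(100, 20)
y = torch.tensor([0] * 50 + [1] * 30 + [2] * 20)
loader = DataLoader(TensorDataset(X, y),
                    batch_sampler=StratifiedBatchSampler(y, batch_size=10))
print(len(loader))
for xb, yb in loader:
    print(torch.bincount(yb, minlength=3).tolist())  # class counts in one batch
    break
```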