How to do a stratified split


Sorry for my English, I am still learning. Thank you for your help.

I have all my data inside a torchvision.datasets.ImageFolder. The idea is to split the data with a stratified method. For that purpose, I am doing it this way:

dataset = torchvision.datasets.ImageFolder(train_dir, transform=train_transform)
targets = dataset.targets

targets is an array of 0s and 1s (2-class classification), something like this:
[0, 0, 1, 1, 0, 1,…]

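import numpy as np
from sklearn.model_selection import train_test_split

# Split the positions 0..N-1 (80/20), keeping the 0/1 ratio in both parts
train_idx, valid_idx = train_test_split(
    np.arange(len(targets)), test_size=0.2, shuffle=True, stratify=targets)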

I got stratified indices for train (80%) and valid (20%), something like this:

train_idx = [12, 54, 23, 123, 80…]
valid_idx = [77, 9, 67, 122, 665…]

These indices give the positions of my 0s and 1s, chosen so that both classes are split in a balanced way.
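To make "stratified" concrete, here is a hand-rolled sketch of the idea in plain Python: group the positions by label, then send the same fraction of every class to validation. This is an illustrative stand-in, not sklearn's actual implementation; `stratified_split` is a hypothetical helper, and the toy `targets` list assumes 60 samples of class 0 and 40 of class 1.

```python
import random
from collections import defaultdict

def stratified_split(targets, test_size=0.2, seed=0):
    """Hypothetical helper: return (train_idx, valid_idx) with per-class proportions kept."""
    rng = random.Random(seed)
    # Group sample positions by their class label
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)
    train_idx, valid_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Move the same fraction of each class into validation
        n_valid = round(len(indices) * test_size)
        valid_idx.extend(indices[:n_valid])
        train_idx.extend(indices[n_valid:])
    return sorted(train_idx), sorted(valid_idx)

targets = [0] * 60 + [1] * 40  # toy labels: 60% class 0, 40% class 1
train_idx, valid_idx = stratified_split(targets)
print(len(train_idx), len(valid_idx))       # 80 20
print(sum(targets[i] for i in valid_idx))   # 8 (class-1 samples in validation, still 40%)
```

Passing `stratify=targets` to `train_test_split` gives the same per-class guarantee without writing this yourself.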
Then I try to split my real data by creating my train/valid loaders like this:

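from torch.utils.data import DataLoader, SubsetRandomSampler

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
valid_loader = DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)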

But len(train_loader.dataset) returns the size of the whole dataset instead of 80% of it. What am I doing wrong?


You are not doing anything wrong: the batches produced by train_loader and valid_loader will only use the indices provided by the corresponding sampler.
Since the underlying .dataset is not changed, the len will be the same as the original one.
However, during the creation of the batches, the samplers will use the provided indices to load the data from the internal Dataset.
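A minimal sketch of this behavior, assuming a small `TensorDataset` as a stand-in for the ImageFolder and hypothetical hand-picked split indices: `len(loader.dataset)` reports the full dataset, while `len(loader.sampler)` reports how many indices the sampler will actually draw.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy stand-in for the ImageFolder: 10 samples, 5 of each class
data = torch.arange(10).float().unsqueeze(1)
labels = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1, 0, 1])
dataset = TensorDataset(data, labels)

# Hypothetical stratified 80/20 split: one sample of each class goes to validation
train_idx = [0, 3, 4, 5, 6, 7, 8, 9]
valid_idx = [1, 2]

train_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(train_idx))
valid_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(valid_idx))

print(len(train_loader.dataset))  # 10: the unchanged underlying dataset
print(len(train_loader.sampler))  # 8: indices the sampler will draw
print(len(valid_loader.sampler))  # 2
print(len(train_loader))          # 2: number of batches (8 samples / batch_size 4)
```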

Alternatively, you could also wrap your dataset in a Subset, which will then change the length.
Besides that, both approaches should yield the same result.
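The sketch below compares the two approaches, again using a toy `TensorDataset` and hypothetical indices in place of the real data. Each sample stores its own index as its value, so the batches show which indices were used: wrapping in `Subset` changes `len(loader.dataset)`, but over one epoch both loaders yield the same samples (in different random orders).

```python
import torch
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler, TensorDataset

# Toy dataset: sample i holds the value i
data = torch.arange(10).float().unsqueeze(1)
labels = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1, 0, 1])
dataset = TensorDataset(data, labels)
train_idx = [0, 3, 4, 5, 6, 7, 8, 9]  # hypothetical stratified split

sampler_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(train_idx))
subset_loader = DataLoader(Subset(dataset, train_idx), batch_size=4, shuffle=True)

def seen_values(loader):
    # Collect every sample value the loader yields in one epoch, sorted
    return sorted(int(v) for x, _ in loader for v in x.flatten())

print(len(sampler_loader.dataset), len(subset_loader.dataset))   # 10 8
print(seen_values(sampler_loader) == seen_values(subset_loader))  # True
```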


Oh! Thank you! I thought that the len of the two .dataset attributes would be different, but what you say makes sense. Thanks!

Hi, if I want to wrap the dataset in a Subset, how should I rewrite the code? Please help.

Working with Subset is fairly easy: you pass the dataset and the indices to Subset. The benefit of using Subset over a Sampler is that you can set shuffle=True in the DataLoader.

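import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

# train_data is the ImageFolder dataset; train_data.targets holds the class labels
# Stratified sampling for train and val
train_idx, validation_idx = train_test_split(np.arange(len(train_data)),
                                             test_size=0.2,
                                             shuffle=True,
                                             stratify=train_data.targets)

# Subset dataset for train and val
train_dataset = Subset(train_data, train_idx)
validation_dataset = Subset(train_data, validation_idx)

# Dataloader for train and val
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=False)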
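Why `shuffle=True` only works with the Subset version: a sampler already decides which indices are drawn and in what order, so DataLoader rejects a sampler combined with `shuffle=True` by raising a ValueError. A minimal check, assuming a toy all-zeros dataset:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy dataset; only its length matters here
dataset = TensorDataset(torch.zeros(10, 1), torch.zeros(10, dtype=torch.long))

try:
    # Sampler and shuffle both control ordering, so DataLoader refuses the combination
    DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler([0, 1, 2]), shuffle=True)
    error = None
except ValueError as exc:
    error = str(exc)

print(error)
```

With `Subset`, the indices are baked into the dataset itself, so the DataLoader is free to shuffle them.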