How to do a stratified split

Hello.

Sorry for my English, I am still learning. Thank you for your help.

I have all my data inside a torchvision.datasets.ImageFolder. The idea is to split the data with a stratified method. For that purpose, I am using torch.utils.data.SubsetRandomSampler in this way:

import numpy as np
import torch
import torchvision

dataset = torchvision.datasets.ImageFolder(train_dir, transform=train_transform)
targets = dataset.targets

targets is a list of 0s and 1s (binary classification), something like this:
[0, 0, 1, 1, 0, 1,…]

from sklearn.model_selection import train_test_split

train_idx, valid_idx = train_test_split(
    np.arange(len(targets)),
    test_size=0.2,
    shuffle=True,
    stratify=targets)
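
To see what stratify=targets actually does, here is a minimal sketch that checks the class counts in each split. It assumes hypothetical imbalanced targets (70 zeros, 30 ones) instead of the real ImageFolder labels, and a fixed random_state only so the result is reproducible:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced binary targets: 70 zeros and 30 ones
targets = [0] * 70 + [1] * 30

train_idx, valid_idx = train_test_split(
    np.arange(len(targets)),
    test_size=0.2,
    shuffle=True,
    stratify=targets,
    random_state=0,  # assumption: fixed seed only for reproducibility
)

targets_arr = np.asarray(targets)
# Count each class inside each split; stratification keeps the 70/30 ratio
train_counts = np.bincount(targets_arr[train_idx])
valid_counts = np.bincount(targets_arr[valid_idx])

print(len(train_idx), len(valid_idx))  # 80 20
print(train_counts, valid_counts)      # [56 24] [14  6]
```

Both splits keep the same 70/30 class ratio as the full set, which is the whole point of passing stratify.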

I get stratified indices for train (80%) and valid (20%), something like this:

train_idx = [12, 54, 23, 123, 80…]
valid_idx = [77, 9, 67, 122, 665…]

These indices give the positions of my 0s and 1s, chosen so that both classes are split in a balanced way.
Then I try to split my actual data by creating my train/valid loaders like this:

train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
valid_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)

But len(train_loader.dataset) returns the size of the whole dataset instead of 80% of it. What am I doing wrong?

You are not doing anything wrong: the batches provided by train_loader and valid_loader will only use the indices from their corresponding sampler.
Since the underlying .dataset is not changed, its len will be the same as the original one.
However, when the batches are created, each sampler uses its indices to load the data from the internal Dataset.
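
A minimal sketch to check this, assuming a hypothetical TensorDataset of 100 random tensors as a stand-in for the ImageFolder (the sizes, batch_size=16 and random_state=0 are all illustrative). The split size is reported by len(loader.sampler), not by len(loader.dataset):

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical stand-in for the ImageFolder: 100 fake "images" with binary labels
targets = [0] * 70 + [1] * 30
dataset = TensorDataset(torch.randn(100, 3, 8, 8), torch.tensor(targets))

train_idx, valid_idx = train_test_split(
    np.arange(len(targets)), test_size=0.2, shuffle=True,
    stratify=targets, random_state=0,
)

train_loader = DataLoader(dataset, batch_size=16,
                          sampler=SubsetRandomSampler(train_idx.tolist()))
valid_loader = DataLoader(dataset, batch_size=16,
                          sampler=SubsetRandomSampler(valid_idx.tolist()))

# .dataset is the untouched full dataset, so its length does not change
print(len(train_loader.dataset))  # 100
# The sampler only holds the selected indices, so its length is the split size
print(len(train_loader.sampler), len(valid_loader.sampler))  # 80 20

# Iterating the loader visits only the sampled indices
n_seen = sum(x.shape[0] for x, _ in train_loader)
print(n_seen)  # 80
```

Note that len(train_loader) counts batches (here 5), so use len(train_loader.sampler) when you need the number of samples in the split.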

Alternatively, you could wrap your dataset in a Subset, which does change the length.
Apart from that, both approaches should yield the same result.
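
A minimal sketch of the Subset variant, again assuming a hypothetical TensorDataset in place of the ImageFolder (with the real data you would pass the ImageFolder instance as the first argument of Subset). Because the subset itself is restricted, no sampler is needed and shuffle=True handles the reshuffling:

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for the ImageFolder: 100 fake "images" with binary labels
targets = [0] * 70 + [1] * 30
dataset = TensorDataset(torch.randn(100, 3, 8, 8), torch.tensor(targets))

train_idx, valid_idx = train_test_split(
    np.arange(len(targets)), test_size=0.2, shuffle=True,
    stratify=targets, random_state=0,
)

# Subset restricts the dataset itself to the chosen indices
train_set = Subset(dataset, train_idx.tolist())
valid_set = Subset(dataset, valid_idx.tolist())

# No sampler needed: shuffle=True reshuffles the training subset every epoch
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=16, shuffle=False)

# Now .dataset is the Subset, so its length is the split size
print(len(train_loader.dataset), len(valid_loader.dataset))  # 80 20
```

One small difference from the sampler version: here the validation set is iterated in a fixed order, which does not affect validation metrics.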

Oh! Thank you! I thought the len of each .dataset would be different, but what you say makes sense. Thanks!

Hi, if I want to wrap the dataset in a Subset, how should I rewrite the code? Please help.