How can I split a dataset from torchvision.datasets by label for cross-validation?


#1

I want to use the CIFAR-10 dataset from torchvision.datasets for classification.

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

Each item in trainset looks like this:

type(trainset[1]) # tuple
type(trainset[1][0]) # torch.Tensor (image (3,32,32))
type(trainset[1][1]) # int (label 0~9)

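Next, I would like to split the training data into training and validation subsets for 5-fold cross-validation.

How can I split the dataset by label, so that every fold keeps the same number of images per label, rather than splitting purely at random?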

For example, the training set has 50000 images and 10 labels, with 5000 images per label. Each fold of 5-fold cross-validation should then contain:
40000 training images, 4000 per label
10000 validation images, 1000 per label
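
To make the target split concrete, here is a minimal sketch of what I have in mind, written in plain Python so it runs without downloading anything. The helper name stratified_kfold_indices is hypothetical (it is not part of torch or torchvision), and the synthetic labels list only stands in for the real CIFAR-10 labels.

```python
import random
from collections import defaultdict


def stratified_kfold_indices(labels, n_splits=5, seed=0):
    """Hypothetical helper: return one (train_indices, val_indices) pair per fold,
    with each label spread as evenly as possible across the folds."""
    rng = random.Random(seed)

    # Group sample indices by their label.
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    # Shuffle within each label, then deal the indices out to the folds.
    folds = [[] for _ in range(n_splits)]
    for label in sorted(by_label):
        idxs = by_label[label]
        rng.shuffle(idxs)
        for k in range(n_splits):
            folds[k].extend(idxs[k::n_splits])

    # Fold k is the validation set; the remaining folds form the training set.
    splits = []
    for k in range(n_splits):
        val_idx = sorted(folds[k])
        train_idx = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        splits.append((train_idx, val_idx))
    return splits


# Synthetic stand-in for the CIFAR-10 labels: 10 labels, 5000 samples each.
labels = [i % 10 for i in range(50000)]
splits = stratified_kfold_indices(labels, n_splits=5, seed=0)

train_idx, val_idx = splits[0]
print(len(train_idx), len(val_idx))  # 40000 10000
```

With the real dataset, I assume the labels could be read from trainset.targets (available in recent torchvision versions), and each index list could then be wrapped with torch.utils.data.Subset(trainset, train_idx) or passed as SubsetRandomSampler(train_idx) to a DataLoader. If scikit-learn is available, sklearn.model_selection.StratifiedKFold appears to produce the same kind of per-label balanced index splits.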

Thanks in advance.