Using data subsets

Hi everyone :slight_smile:

I have a question about how exactly torch.utils.data.Subset works. From my understanding, it returns a subset of the data that contains the same number of classes but only a subset of the datapoints for each class. For example, data = torch.utils.data.Subset(trainset, range(0, len(trainset), 2)) would give me half of the entire data with half of the datapoints for each class, right?

If so, is there an easy way to use a subset of the classes as well? In other words, a way to only take half of the classes of a dataset and maybe only half of the datapoints of each of those classes?

Any help is very much appreciated!

All the best,

Here is one way to do it:

import torch
from torchvision.datasets import CIFAR10
trainset = CIFAR10('./data', train=True, download=True)

# select classes you want to include in your subset
classes = torch.tensor([0, 1, 2, 3, 4])

# get indices that correspond to one of the selected classes
indices = (torch.tensor(trainset.targets)[..., None] == classes).any(-1).nonzero(as_tuple=True)[0]

# subset the dataset
data = torch.utils.data.Subset(trainset, indices)

# or take every second selected datapoint (roughly half of each of those classes)
data2 = torch.utils.data.Subset(trainset, indices[::2])

# a random half of the selected datapoints (random overall, not stratified per class)
total = len(data)
data3 = torch.utils.data.random_split(data, [total//2, total-total//2])[0]
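If you need exactly half of the datapoints of each selected class, rather than roughly half, one option is to pick the indices class by class. Below is a minimal sketch of that idea, not something taken from the library: the helper name stratified_half_indices, its fraction and seed parameters, and the toy targets list are all hypothetical, and it assumes the labels are plain Python ints (as trainset.targets is for CIFAR10).

```python
import random
from collections import defaultdict


def stratified_half_indices(targets, classes, fraction=0.5, seed=0):
    """Pick `fraction` of the sample indices of each class in `classes`.

    Hypothetical helper for illustration; `targets` is assumed to be a
    sequence of plain Python ints, one label per datapoint.
    """
    rng = random.Random(seed)  # local RNG so the selection is reproducible
    wanted = set(classes)

    # group the indices of the selected classes by label
    per_class = defaultdict(list)
    for idx, label in enumerate(targets):
        if label in wanted:
            per_class[label].append(idx)

    # draw the same fraction from every class, without replacement
    selected = []
    for label in sorted(per_class):
        idxs = per_class[label]
        k = int(len(idxs) * fraction)
        selected.extend(rng.sample(idxs, k))
    return sorted(selected)


# toy labels standing in for trainset.targets: 10 classes, 20 samples each
targets = [i % 10 for i in range(200)]
indices = stratified_half_indices(targets, classes=[0, 1, 2, 3, 4])
```

With the real dataset you would pass trainset.targets instead of the toy list and then hand the result to torch.utils.data.Subset(trainset, indices), the same way as in the snippet above.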

Thank you @Sobir_Bobiev! :slight_smile: