Using data subsets

Hi everyone 🙂

I have a question about how exactly torch.utils.data.Subset() works. My understanding is that it returns a subset of the data that still contains every class, but only some of the datapoints of each class. For example, data = torch.utils.data.Subset(trainset, range(0, len(trainset), 2)) would give me half of the whole dataset, with half of the datapoints of each class, right?

If so, is there an easy way to select a subset of the classes as well? In other words, is there a way to take only half of the classes in a dataset, and perhaps only half of the datapoints of each of those classes?

Any help is very much appreciated!

All the best,
snowe
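On the first question: `Subset(dataset, indices)` does not look at class labels at all. Item `i` of the subset is simply `dataset[indices[i]]`, so taking every second index halves each class only if the samples of each class are spread evenly through the dataset. The sketch below uses two small, made-up `TensorDataset`s (the toy labels and the `class_counts` helper are assumptions for illustration, not part of CIFAR10 or the PyTorch API) to show that the same index selection can give very different per-class counts depending on the order of the samples.

```python
import torch
from torch.utils.data import Subset, TensorDataset

def class_counts(subset, num_classes):
    # Hypothetical helper: collect the label of every item in the subset
    # and count how many items fall into each class.
    labels = torch.stack([subset[i][1] for i in range(len(subset))])
    return torch.bincount(labels, minlength=num_classes).tolist()

# Toy data: 8 samples, 2 classes (hypothetical example data).
features = torch.arange(8, dtype=torch.float32).unsqueeze(1)
sorted_labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])       # grouped by class
interleaved_labels = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1])  # alternating classes

sorted_ds = TensorDataset(features, sorted_labels)
interleaved_ds = TensorDataset(features, interleaved_labels)

# Identical index selection (every second sample: 0, 2, 4, 6) on both datasets.
every_other = range(0, 8, 2)
sorted_half = Subset(sorted_ds, every_other)
interleaved_half = Subset(interleaved_ds, every_other)

print(class_counts(sorted_half, 2))       # [2, 2]: each class halved
print(class_counts(interleaved_half, 2))  # [4, 0]: class 1 dropped entirely
```

For a dataset like CIFAR10, whose samples are not grouped by class, every second index usually gives roughly, but not exactly, half of each class.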

Hi,
Here is one way to do it:

import torch
from torchvision.datasets import CIFAR10
trainset = CIFAR10('./data', train=True, download=True)

# select classes you want to include in your subset
classes = torch.tensor([0, 1, 2, 3, 4])

# get indices of the samples whose label is one of the selected classes
# (targets[..., None] == classes broadcasts to an (N, 5) boolean matrix)
indices = (torch.tensor(trainset.targets)[..., None] == classes).any(-1).nonzero(as_tuple=True)[0]

# subset the dataset
data = torch.utils.data.Subset(trainset, indices)

# or roughly half of the datapoints of each of those classes: every second
# selected index (exact per-class counts depend on the order of the samples)
data2 = torch.utils.data.Subset(trainset, indices[::2])

# a random half of the selected datapoints (random_split ignores the labels,
# so each class is only approximately halved)
total = len(data)
data3 = torch.utils.data.random_split(data, [total//2, total-total//2])[0]
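If you need exactly half of each selected class rather than roughly half, one option is to subsample the indices class by class before building the `Subset`. The sketch below does this with `torch.randperm` on a small made-up dataset; the helper name `per_class_subset_indices` and the toy labels are assumptions for illustration. With CIFAR10 you would pass `torch.tensor(trainset.targets)` as the labels and wrap `trainset` instead of the toy dataset.

```python
import torch
from torch.utils.data import Subset, TensorDataset

def per_class_subset_indices(targets, classes, fraction, generator=None):
    """Hypothetical helper: for each class in `classes`, randomly keep
    `fraction` of its samples and return the combined sample indices."""
    selected = []
    for c in classes:
        # Positions of all samples whose label is c.
        class_idx = (targets == c).nonzero(as_tuple=True)[0]
        # Number of samples of this class to keep (rounded down).
        k = int(len(class_idx) * fraction)
        # Random permutation of this class's positions; keep the first k.
        perm = torch.randperm(len(class_idx), generator=generator)
        selected.append(class_idx[perm[:k]])
    # Concatenate and sort so the subset keeps the original sample order.
    return torch.cat(selected).sort().values

# Toy labels: 3 classes with 6 samples each (hypothetical example data).
targets = torch.arange(3).repeat(6)  # [0, 1, 2, 0, 1, 2, ...]
features = torch.randn(len(targets), 4)
dataset = TensorDataset(features, targets)

gen = torch.Generator().manual_seed(0)  # fixed seed for a reproducible split
indices = per_class_subset_indices(targets, classes=[0, 1], fraction=0.5, generator=gen)
half_of_two_classes = Subset(dataset, indices.tolist())

print(len(half_of_two_classes))                                 # 6
print(torch.bincount(targets[indices], minlength=3).tolist())  # [3, 3, 0]
```

Because `int()` rounds down, a class with an odd number of samples contributes one sample fewer than half; round up instead if you prefer.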

Thank you @Sobir_Bobiev! 🙂