I have a question about how exactly torch.utils.data.Subset() works. My understanding is that it returns a subset of the data that keeps the same number of classes but only a subset of the datapoints for each class. For example, data = torch.utils.data.Subset(trainset, range(0, len(trainset), 2)) would give me half of the entire dataset, with half of the datapoints for each class, right?
If so, is there an easy way to use a subset of the classes as well? In other words, a way to only take half of the classes of a dataset and maybe only half of the datapoints of each of those classes?
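On the first part: Subset knows nothing about classes. It only stores the indices you pass and looks them up in the wrapped dataset, so which classes end up in the subset depends entirely on which indices you choose. A minimal sketch on a made-up toy dataset (the names toy, features and labels are only for illustration, not part of the original question):

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Hypothetical toy dataset: 8 samples sorted by label, 4 classes with 2 samples each.
features = torch.arange(8).float().unsqueeze(1)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
toy = TensorDataset(features, labels)

# Subset keeps the given indices and forwards lookups to the wrapped dataset.
every_other = Subset(toy, range(0, len(toy), 2))    # indices 0, 2, 4, 6
first_half = Subset(toy, range(0, len(toy) // 2))   # indices 0, 1, 2, 3

# Collect the labels that each subset actually contains.
every_other_labels = [int(every_other[i][1]) for i in range(len(every_other))]
first_half_labels = [int(first_half[i][1]) for i in range(len(first_half))]
```

Both subsets hold half of the data, but the first one still covers every class while the second one only contains classes 0 and 1. With a shuffled dataset such as CIFAR10, taking every second index gives roughly, not exactly, half of each class.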
import torch
from torchvision.datasets import CIFAR10
trainset = CIFAR10('./data', train=True, download=True)
# select classes you want to include in your subset
classes = torch.tensor([0, 1, 2, 3, 4])
# get indices that correspond to one of the selected classes
indices = (torch.tensor(trainset.targets)[..., None] == classes).any(-1).nonzero(as_tuple=True)[0]
# subset the dataset
data = torch.utils.data.Subset(trainset, indices)
# or every other selected datapoint (roughly, not exactly, half of each of those classes)
data2 = torch.utils.data.Subset(trainset, indices[::2])
# a random half of the selected datapoints (class balance is only approximate)
total = len(data)
data3 = torch.utils.data.random_split(data, [total//2, total-total//2])[0]
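Note that indices[::2] and random_split above halve the selected samples as a whole, so the per-class counts are only approximately halved. If you need exactly half of each class, one option is to sample the indices class by class. This is a sketch under assumptions: half_per_class is a hypothetical helper (not a PyTorch function), and it is shown on a small synthetic dataset so it runs without downloading CIFAR10.

```python
import torch
from torch.utils.data import Subset, TensorDataset

def half_per_class(targets, classes, generator=None):
    """Return indices covering a random half (rounded down) of each selected class."""
    targets = torch.as_tensor(targets)
    picked = []
    for c in classes:
        # positions of all samples that belong to class c
        idx = (targets == c).nonzero(as_tuple=True)[0]
        # shuffle those positions and keep the first half
        perm = torch.randperm(len(idx), generator=generator)
        picked.append(idx[perm[: len(idx) // 2]])
    return torch.cat(picked)

# Hypothetical toy data: 4 classes with 6 samples each.
labels = torch.tensor([0, 1, 2, 3] * 6)
toy = TensorDataset(torch.randn(24, 3), labels)

gen = torch.Generator().manual_seed(0)  # fixed seed for a reproducible selection
indices = half_per_class(labels, classes=[0, 1], generator=gen)
balanced = Subset(toy, indices.tolist())
```

For the CIFAR10 example the same helper could be called as half_per_class(trainset.targets, [0, 1, 2, 3, 4]) and the result passed to torch.utils.data.Subset(trainset, ...).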