Using data subsets

Hi everyone 🙂

I have a question about how exactly torch.utils.data.Subset() works. My understanding is that it returns a subset of the data that keeps all of the classes but only some of the data points of each class. For example, data = torch.utils.data.Subset(trainset, range(0, len(trainset), 2)) would give me half of the entire dataset, with half of the data points of each class, right?

If so, is there an easy way to use a subset of the classes as well? In other words, is there a way to take only half of the classes of a dataset, and maybe only half of the data points of each of those classes?

Any help is very much appreciated!

All the best,
snowe


Hi,
Here is one way to do it:

```python
import torch
from torchvision.datasets import CIFAR10

trainset = CIFAR10('./data', train=True, download=True)

# select the classes you want to include in your subset
classes = torch.tensor([0, 1, 2, 3, 4])

# get the indices of all samples whose label is one of the selected classes:
# targets of shape [N, 1] are compared against classes of shape [5] by broadcasting,
# and a row is kept if any of its comparisons is True
indices = (torch.tensor(trainset.targets)[..., None] == classes).any(-1).nonzero(as_tuple=True)[0]

# subset the dataset
data = torch.utils.data.Subset(trainset, indices)

# or every second selected sample (roughly half of each of those classes)
data2 = torch.utils.data.Subset(trainset, indices[::2])

# or a random half of the selected samples (again, only roughly half per class)
total = len(data)
data3 = torch.utils.data.random_split(data, [total // 2, total - total // 2])[0]
```
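Regarding the assumption in the question: Subset only stores the wrapped dataset and the index sequence you pass it, and it never inspects the labels. A minimal sketch to check what range(0, len(dataset), 2) actually keeps, using a made-up six-sample TensorDataset (the toy data is purely illustrative):

```python
from collections import Counter

import torch
from torch.utils.data import Subset, TensorDataset

# Toy dataset (illustrative only): 6 samples sorted by label, 3 of class 0 and 3 of class 1
features = torch.arange(6, dtype=torch.float32).unsqueeze(1)
labels = torch.tensor([0, 0, 0, 1, 1, 1])
toy = TensorDataset(features, labels)

# Subset just remembers the wrapped dataset and the index sequence
half = Subset(toy, range(0, len(toy), 2))
print(list(half.indices))  # [0, 2, 4]

# count how many samples of each class made it into the subset
counts = Counter(int(half[i][1]) for i in range(len(half)))
print(counts)  # Counter({0: 2, 1: 1})
```

Here class 0 keeps two of its three samples and class 1 keeps only one, so taking every other index halves the dataset but only approximately halves each class; how close you get depends on how the samples are ordered.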

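Both indices[::2] and random_split halve the combined subset rather than each class individually. If an exact per-class fraction matters, one option is to sample within each class separately. A minimal sketch, where per_class_fraction is a hypothetical helper (not part of torch.utils.data) and the toy labels are made up for illustration:

```python
import torch
from torch.utils.data import Subset, TensorDataset


def per_class_fraction(targets, classes, fraction=0.5, generator=None):
    """Hypothetical helper: indices of a random `fraction` of each class in `classes`."""
    targets = torch.as_tensor(targets)
    kept = []
    for c in classes:
        # positions of every sample labelled `c`
        class_idx = (targets == c).nonzero(as_tuple=True)[0]
        # shuffle within the class and keep the first floor(fraction * count) of them
        perm = torch.randperm(len(class_idx), generator=generator)
        n_keep = int(len(class_idx) * fraction)
        kept.append(class_idx[perm[:n_keep]])
    # return plain Python ints so Subset indexes the wrapped dataset with ordinary integers
    return torch.cat(kept).tolist()


# Toy data (illustrative only): 4 samples each of classes 0, 1 and 2
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
toy = TensorDataset(torch.arange(12).unsqueeze(1), labels)

# keep a random half of classes 0 and 1, drop class 2 entirely
g = torch.Generator().manual_seed(0)
idx = per_class_fraction(labels, classes=[0, 1], fraction=0.5, generator=g)
subset = Subset(toy, idx)
print(len(subset), sorted(int(labels[i]) for i in idx))  # 4 [0, 0, 1, 1]
```

For CIFAR10 you could pass trainset.targets and the selected class list, e.g. Subset(trainset, per_class_fraction(trainset.targets, [0, 1, 2, 3, 4])). Passing a seeded torch.Generator makes the selection reproducible across runs.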

Thank you @Sobir_Bobiev! 🙂