How to extract a certain subset from a torchvision dataset?

I’m using a torchvision dataset (FMNIST), and I’d like to extract a subset from the training portion and use it as my validation dataset. For example, I’d like to extract half of the class 1 instances from the training data and use them only in my validation data.

How can I do it?

The dataset’s __getitem__ returns a tuple of (image, target), where the target is the index of the target class. You should be able to iterate through the dataset and identify the indices that you want to keep or discard based on your criteria.
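As a rough sketch of that iteration, the helper below groups dataset positions by class. The name indices_by_class is made up for illustration, and the toy list only stands in for a real dataset; the helper assumes nothing beyond len() and (image, target) indexing.

```python
from collections import defaultdict


def indices_by_class(dataset):
    """Map each class index to the list of dataset positions that hold it."""
    groups = defaultdict(list)
    for idx in range(len(dataset)):
        _, target = dataset[idx]  # __getitem__ yields (image, target)
        groups[int(target)].append(idx)
    return dict(groups)


# Toy stand-in for a dataset: (image, target) pairs with placeholder images.
toy_dataset = [("img0", 0), ("img1", 1), ("img2", 1), ("img3", 2)]
groups = indices_by_class(toy_dataset)
```

Once the positions are grouped like this, picking "half of class 1" is just a matter of slicing or sampling groups[1].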

One way to do this is:

  1. Iterate through the dataset and find the indices you want for training and the indices you want for validation
  2. Set train_ds = torch.utils.data.Subset(dataset, train_indices), and val_ds = torch.utils.data.Subset(dataset, val_indices)
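The two steps above might look roughly like the following. This is a sketch under assumptions: a small synthetic TensorDataset stands in for FMNIST so it runs without a download, and the fixed seed and variable names are illustrative choices, not anything prescribed by torchvision.

```python
import random

import torch
from torch.utils.data import Subset, TensorDataset

# Synthetic stand-in for the training set (assumption, not real FMNIST):
# 20 blank 1x28x28 "images" with labels cycling through 0..3.
images = torch.zeros(20, 1, 28, 28)
labels = torch.arange(20) % 4
dataset = TensorDataset(images, labels)

# Step 1: find the class 1 indices, then move a random half to validation.
class1 = [i for i in range(len(dataset)) if int(dataset[i][1]) == 1]
rng = random.Random(0)  # fixed seed so the split is reproducible
val_indices = sorted(rng.sample(class1, len(class1) // 2))
val_set = set(val_indices)
train_indices = [i for i in range(len(dataset)) if i not in val_set]

# Step 2: wrap both index lists in Subset views over the same dataset.
train_ds = Subset(dataset, train_indices)
val_ds = Subset(dataset, val_indices)
```

With the real data, dataset would instead be the torchvision training set. Its targets attribute holds all the labels, so reading from that is usually faster than indexing every item, since indexing also loads and transforms each image.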

Thanks. I didn’t know of data.Subset.