Filtering a class from a subset of FashionMNIST

Hi everybody,
I’m trying to learn how to use datasets from torchvision. As the title says, I want to filter a specific subset of FashionMNIST, a dataset that I have already split using random_split. The code I have written so far is:

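    import torch
    from torchvision import datasets, transforms

    trainloader_list = []
    x = 5  # number of splits (assumed value; x is not defined in the original post)

    # Define a transform to normalize the data
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])

    # Download and load the training data
    trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)

    # Split into x random parts; random_split requires the lengths to sum to len(trainset)
    lengths = [len(trainset) // x] * x
    lengths[-1] += len(trainset) % x  # give any remainder to the last split
    trainset_list = torch.utils.data.random_split(trainset, lengths)
    for i_trainset in trainset_list:
        trainloader_list.append(torch.utils.data.DataLoader(i_trainset, batch_size=64, shuffle=True))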

My idea was to add something like this after splitting the dataset into x parts, to remove class y from one of those portions:

    trainset_list[0].train_data = trainset_list[0].train_data[trainset_list[0].train_labels != y]
    trainset_list[0].train_labels = trainset_list[0].train_labels[trainset_list[0].train_labels != y]

Of course, the code gives me the error ‘Subset’ object has no attribute ‘train_data’, because after splitting, each part is now a Subset.

How can I achieve the filtering?

Subset will hold the passed Dataset in its internal .dataset attribute, so trainset_list[0].dataset.train_data might work for you.

Unfortunately, doing that gives me back the whole dataset (every sample whose class is not nine), losing the split I made before.

I wasn’t sure what your use case is and just focused on the error message.
Why would you like to remove specific indices from each split? Note that each Subset has its own indices and uses these to index the underlying Dataset directly.
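To illustrate how a Subset relates to the underlying Dataset, here is a minimal sketch using a small synthetic dataset in place of FashionMNIST (`ToyDataset` is a hypothetical class that only mimics the `.targets` attribute torchvision datasets expose). It shows that `.dataset` is the complete dataset, which is why going through it loses the split, while each position of the split is mapped through `.indices`:

```python
import torch
from torch.utils.data import Dataset, random_split

class ToyDataset(Dataset):
    """Hypothetical stand-in for FashionMNIST with a torchvision-style .targets tensor."""

    def __init__(self, n=100, num_classes=10):
        self.data = torch.arange(n, dtype=torch.float32)  # one scalar "image" per sample
        self.targets = torch.arange(n) % num_classes      # labels 0..9, repeating

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.data[idx], int(self.targets[idx])

dataset = ToyDataset()
# Five equal splits; the fixed seed only makes the example reproducible
splits = random_split(dataset, [20] * 5, generator=torch.Generator().manual_seed(0))
first = splits[0]

# .dataset is the complete underlying dataset, not the 20-sample split
print(first.dataset is dataset, len(first.dataset), len(first))  # True 100 20

# Position i of the split is simply position indices[i] of the full dataset
i = 3
x_split, y_split = first[i]
x_full, y_full = dataset[first.indices[i]]
print(torch.equal(x_split, x_full), y_split == y_full)  # True True
```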

My goal is to have one split without one class, train a network on it, and then create another network that will be trained on new data containing one more label. I partly succeeded by creating my own class that mimics a dataset, but for now I’m not sure whether I can trust the training results.
My point is: I want to be sure that the batches of data are drawn uniformly from the original dataset, and then remove one label from just one split to train the first network.
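A minimal sketch of that workflow, assuming a small synthetic dataset in place of FashionMNIST (`ToyDataset` and the `drop_class` helper are hypothetical names, not part of torchvision). `random_split` shuffles all indices uniformly before cutting them into pieces, so each split is a uniform sample of the data; the class is then removed from the first split only, by filtering that split's `.indices` against the base dataset's `.targets` and building a new Subset over the same base dataset:

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split

class ToyDataset(Dataset):
    """Hypothetical stand-in for FashionMNIST: 100 samples, labels 0..9, torchvision-style .targets."""

    def __init__(self, n=100, num_classes=10):
        self.data = torch.randn(n, 1, 28, 28)
        self.targets = torch.arange(n) % num_classes

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.data[idx], int(self.targets[idx])

def drop_class(subset, label):
    """Hypothetical helper: a new Subset of the same base dataset without samples of `label`."""
    indices = torch.as_tensor(subset.indices)        # works whether .indices is a list or a tensor
    keep = subset.dataset.targets[indices] != label  # look up labels without loading any images
    return Subset(subset.dataset, indices[keep].tolist())

dataset = ToyDataset()
num_splits = 5
splits = random_split(dataset, [len(dataset) // num_splits] * num_splits,
                      generator=torch.Generator().manual_seed(0))

# Only the first split loses class 9; the other splits are left untouched
splits[0] = drop_class(splits[0], label=9)
loaders = [DataLoader(s, batch_size=8, shuffle=True) for s in splits]
```

For the real FashionMNIST trainset, `.targets` is also a tensor, so the same helper should apply to the output of `random_split` without changes.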

If you know of another, easier method, I would be grateful :slight_smile:

I think in your use case it might be easier to manipulate the indices before creating the Subsets as shown in this small example:

    import numpy as np
    from torch.utils.data import Subset
    from sklearn.model_selection import train_test_split
    import torchvision.datasets as datasets

    # root path is a placeholder; the original arguments were cut off
    dataset = datasets.MNIST(root='./data', train=True, download=True)

    # Get all targets (as a NumPy array so they can be indexed with NumPy indices)
    targets = np.asarray(dataset.targets)
    # Create target_indices
    target_indices = np.arange(len(targets))
    # Split into train and validation
    train_idx, val_idx = train_test_split(target_indices, train_size=0.8)

    # Specify which class to remove from train
    classidx_to_remove = 0
    # Get indices to keep from train split
    idx_to_keep = targets[train_idx] != classidx_to_remove
    # Only keep your desired classes
    train_idx = train_idx[idx_to_keep]

    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)

Let me know if this works for you.

I thought about that initially, but from what I understand, sklearn’s split can only divide the dataset into two parts, or at least produce two splits at a time. I tried to apply the same method to the splits produced by torch to get multiple splits, but as before, the result is a set of Subsets, so I could not access the targets and data of each subset.
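One way to get more than two splits with the index-based approach is to shuffle all indices once and cut them into k pieces before any Subset is created, so the targets are still plain arrays when the class is removed. The sketch below swaps sklearn for NumPy’s `permutation` and `array_split`, and uses a synthetic `TensorDataset` as a stand-in for FashionMNIST (both are assumptions made to keep the example self-contained):

```python
import numpy as np
import torch
from torch.utils.data import Subset, TensorDataset

# Hypothetical stand-in for FashionMNIST: 100 samples with labels 0..9
images = torch.randn(100, 1, 28, 28)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)
targets = labels.numpy()  # for a torchvision dataset: np.asarray(dataset.targets)

rng = np.random.default_rng(0)
perm = rng.permutation(len(targets))    # uniform shuffle of all indices
index_splits = np.array_split(perm, 5)  # k splits of (nearly) equal size

# Remove class 9 from the first split only, while the indices are still plain arrays
class_to_remove = 9
index_splits[0] = index_splits[0][targets[index_splits[0]] != class_to_remove]

subsets = [Subset(dataset, idx.tolist()) for idx in index_splits]
```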

Thanks for the help, but for now I will keep using this custom class, which seems to work fine:

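    import torch
    from torch.utils.data import Dataset

    class FilteredSubset(Dataset):  # placeholder name; the original class line was not shown
        def __init__(self, subset, label):
            self.subset = subset
            self.targets = []
            self.data = []

            indices = subset.indices
            for index in indices:
                img, target = subset.dataset[int(index)]
                # Keep only the samples whose label equals `label`
                if target == label:
                    self.data.append(img)
                    self.targets.append(target)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            img, target = self.data[idx], int(self.targets[idx])
            return img, target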

Thanks for the answer; this works up to fetching the subset, but training later raises an error if the class labels don’t start from zero or aren’t sequential. Thank you.
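Losses such as `nn.CrossEntropyLoss` expect class indices in the range 0 to C-1, so after removing a class the remaining labels may need to be remapped to a contiguous range. A minimal sketch, assuming a hypothetical wrapper `RemapLabels` and toy data in which only classes 2, 5 and 7 remain:

```python
import torch
from torch.utils.data import Dataset, TensorDataset

class RemapLabels(Dataset):
    """Hypothetical wrapper: maps the labels present in `dataset` to 0..K-1."""

    def __init__(self, dataset, labels):
        self.dataset = dataset
        # Sorted so the mapping is deterministic, e.g. {2: 0, 5: 1, 7: 2}
        self.label_map = {old: new for new, old in enumerate(sorted(set(labels)))}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        return x, self.label_map[int(y)]

# Toy data in which only classes 2, 5 and 7 survive the filtering
labels = torch.tensor([2, 5, 7, 5, 2, 7])
base = TensorDataset(torch.randn(6, 3), labels)
remapped = RemapLabels(base, labels.tolist())
print(remapped.label_map)                              # {2: 0, 5: 1, 7: 2}
print([remapped[i][1] for i in range(len(remapped))])  # [0, 1, 2, 1, 0, 2]
```

An alternative that needs no remapping is to keep the first network’s output layer at the original number of classes and simply never show it the removed class; which option fits better likely depends on how the second network is built.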