Select one kind of data from dataloader

Hi PyTorch community!

I am working on the SVHN dataset and would like to know if there is a simple way to get only the samples with label 1. For example, in NumPy I would do something like this: `dataset[dataset[:, 1] == 1]`. Does anyone know how I can do that in PyTorch?

Thank you :slight_smile:
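If the images and labels are already held in tensors, PyTorch supports the same boolean-mask indexing as NumPy. Here is a minimal sketch with made-up data. The `images` and `labels` tensors are hypothetical stand-ins, not SVHN itself:

```python
import torch

# Hypothetical in-memory data: 6 tiny 3x2x2 "images" and their labels
images = torch.arange(6 * 3 * 2 * 2, dtype=torch.float32).reshape(6, 3, 2, 2)
labels = torch.tensor([1, 0, 1, 3, 1, 2])

# Boolean mask indexing works the same way as in NumPy
mask = labels == 1
ones_images = images[mask]
ones_labels = labels[mask]

print(ones_images.shape)  # torch.Size([3, 3, 2, 2])
print(ones_labels)        # tensor([1, 1, 1])
```

This only works when everything fits in memory as tensors. A `Dataset` object that loads samples lazily has to be filtered by index instead.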

I found a solution:

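```python
from torch.utils.data import Dataset, DataLoader

class FilteredDataset(Dataset):
    """Wraps a dataset and exposes only the samples with the wanted label."""

    def __init__(self, dataset, wanted_label):
        self.parent = dataset
        indices = []
        # Note: this loads every sample once just to read its label
        for index, (img, lab) in enumerate(dataset):
            if lab == wanted_label:
                indices.append(index)
        self.indices = indices

    def __getitem__(self, index):
        # Map the filtered index back to the parent dataset's index
        return self.parent[self.indices[index]]

    def __len__(self):
        return len(self.indices)

one_label_dataset = FilteredDataset(dataset, 1)
# batch_size and shuffle are illustrative values
one_label_dataloader = DataLoader(one_label_dataset, batch_size=64, shuffle=True)
```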
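The `FilteredDataset` approach reads every sample, including decoding each image, just to check its label. A lighter alternative is to compute the matching indices directly from a label array and wrap them in `torch.utils.data.Subset`. This assumes the labels are available without touching the images. As far as I know, torchvision's `SVHN` keeps its labels in a NumPy array attribute `labels`, so for a real SVHN dataset you could pass `svhn.labels`. The helper `subset_by_label` and the random stand-in data below are hypothetical:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

def subset_by_label(dataset, labels, wanted_label):
    """Hypothetical helper: keep only samples whose label equals wanted_label.

    `labels` is a 1-D array-like aligned with the dataset's indices
    (assumed to be `dataset.labels` for torchvision's SVHN).
    """
    labels = np.asarray(labels)
    indices = np.flatnonzero(labels == wanted_label).tolist()
    return Subset(dataset, indices)

# Random stand-in for SVHN so the sketch runs without downloading anything
images = torch.randn(10, 3, 32, 32)
labels = torch.tensor([1, 0, 1, 2, 1, 3, 4, 1, 5, 6])
dataset = TensorDataset(images, labels)

one_label_dataset = subset_by_label(dataset, labels.numpy(), 1)
one_label_dataloader = DataLoader(one_label_dataset, batch_size=2, shuffle=True)

for batch_images, batch_labels in one_label_dataloader:
    # Every label in every batch should be 1
    assert (batch_labels == 1).all()
```

Because `Subset` only stores indices, samples are still loaded lazily by the `DataLoader`, and nothing is read up front.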