Select one kind of data from dataloader

Hi PyTorch community!

I am working on the SVHN dataset and I would like to know if there is a simple way to select only the samples whose label is 1. For example, in NumPy I would do something like this: `dataset[dataset[:, 1] == 1]`. Does anyone know how I can do that in PyTorch?

Thank you :slight_smile:

I found a solution:

import torch.utils.data
class FilteredDataset(torch.utils.data.dataset.Dataset):
    """View of a parent dataset restricted to samples with selected labels.

    The parent is scanned once at construction time to record the indices
    of matching samples; items are then fetched lazily from the parent.
    """

    def __init__(self, dataset, wanted_labels):
        """Record the indices of every sample whose label is wanted.

        Args:
            dataset: indexable/iterable dataset whose items unpack as
                ``(sample, label)``.
            wanted_labels: a single label, or a list/tuple/set/frozenset/range
                of labels to keep.
        """
        self.parent = dataset
        if isinstance(wanted_labels, (list, tuple, set, frozenset, range)):
            targets = tuple(wanted_labels)
        else:
            targets = (wanted_labels,)
        # Compare with ``label == target`` (not set membership): tensor labels
        # hash by identity, so ``tensor(1) in {1}`` would be False, while
        # ``tensor(1) == 1`` evaluates truthy as expected.
        # NOTE: this reads every item of the parent, samples included.
        self.indices = [
            index
            for index, (_, label) in enumerate(dataset)
            if any(label == target for target in targets)
        ]

    def __getitem__(self, index):
        """Return the ``index``-th matching item of the parent dataset."""
        return self.parent[self.indices[index]]

    def __len__(self):
        """Return the number of matching samples."""
        return len(self.indices)

# NOTE(review): ``dataset`` is not defined in this snippet; presumably an
# SVHN dataset yielding ``(image, label)`` pairs — confirm against caller.
one_label_dataset = FilteredDataset(dataset, 1)
print(len(one_label_dataset))  # number of samples whose label is 1
one_label_dataloader = torch.utils.data.DataLoader(
    dataset=one_label_dataset,
    batch_size=128,
    shuffle=True,
)
1 Like