Hi PyTorch community!
I am working on the SVHN dataset and I would like to know if there is a simple way to keep only the samples whose label is 1. For example, in NumPy I would do something like this: dataset[dataset[:, 1] == 1]. Does anyone know how I can do that in PyTorch?
Thank you
I found a solution:
import torch.utils.data


class FilteredDataset(torch.utils.data.Dataset):
    """Read-only view of a dataset restricted to samples with a given label.

    The parent dataset is iterated once at construction time and the
    positions of the matching samples are stored; the samples themselves
    are not copied.

    Args:
        dataset: Parent dataset. Iterating it must yield 2-item samples
            (the second item is treated as the label) and it must support
            integer indexing.
        wanted_labels: Label value to keep. A sample is kept when its
            label compares equal (``==``) to this value.
    """

    def __init__(self, dataset, wanted_labels):
        self.parent = dataset
        # Positions in the parent dataset whose label matches. This walks
        # the whole parent once, so construction cost grows with its size.
        self.indices = [
            index
            for index, (_, label) in enumerate(dataset)
            if label == wanted_labels
        ]

    def __getitem__(self, index):
        """Return the parent sample behind the ``index``-th kept position."""
        return self.parent[self.indices[index]]

    def __len__(self):
        """Return the number of samples that matched the wanted label."""
        return len(self.indices)


# NOTE: `dataset` is assumed to be created earlier (e.g. the SVHN dataset
# mentioned in the question) -- it is not defined in this snippet.
one_label_dataset = FilteredDataset(dataset, 1)
# Fixed: the original printed len(one), but no variable named `one` exists.
print(len(one_label_dataset))
one_label_dataloader = torch.utils.data.DataLoader(
    dataset=one_label_dataset, batch_size=128, shuffle=True
)