Selecting data from dataset by indices

I’m loading the MNIST dataset using torchvision like this:

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/home/achennault/research/mnist/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_train, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/home/achennault/research/mnist/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_test, shuffle=True, drop_last=True)

I have a list of indices into train_loader.dataset for the elements I want to keep. I want to select only these elements of the Dataset, then use the restricted Dataset in the DataLoader. Is there any way to do this?

You could wrap your Dataset in a torch.utils.data.Subset and pass it the indices, so that only those elements are sampled.


Thanks for your reply. I’ve created a Subset object using the relevant indices and the original dataset. However, I’m not sure how to use this in my original DataLoader. Do I simply overwrite the original dataset?

i.e. train_loader.dataset = mysubset

or is there something more complicated that I must do?

You can simply pass the Subset as the new dataset to the DataLoader:

from torch.utils.data import DataLoader, Subset

my_subset = Subset(dataset, indices)  # view onto dataset restricted to indices
loader = DataLoader(my_subset)
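For anyone who wants to try this without downloading MNIST, here is a minimal sketch. It assumes a small synthetic TensorDataset as a stand-in for the real dataset, and the names full_dataset and keep_indices are made up for illustration. Each label equals its sample index, so it is easy to see which elements the loader actually returns.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in for MNIST (assumption): 100 samples whose label equals the sample index.
features = torch.arange(100, dtype=torch.float32).unsqueeze(1)
labels = torch.arange(100)
full_dataset = TensorDataset(features, labels)

# Hypothetical indices of the elements to keep.
keep_indices = [3, 10, 42, 77]

# Subset is a view onto full_dataset; it does not copy or modify the data.
subset = Subset(full_dataset, keep_indices)

# Build a new DataLoader on the subset instead of patching an existing loader.
loader = DataLoader(subset, batch_size=2, shuffle=True)

# Collect every label the loader yields, across all batches.
seen_labels = sorted(int(y) for _, batch_y in loader for y in batch_y)
print(seen_labels)
```

Creating a fresh DataLoader, rather than assigning to train_loader.dataset, also keeps the loader's sampler consistent with the new dataset length.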

This works! Note that DataLoader defaults to batch_size=1, so we have to specify batch_size if we want each batch to contain more than one sample:

my_subset = Subset(dataset, indices)
loader = DataLoader(my_subset, batch_size=len(indices))  # len() also works when indices is a plain list
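To make the effect of batch_size concrete, the sketch below compares the default loader with one whose batch covers the whole subset. It assumes random MNIST-shaped tensors as a stand-in dataset, and the index list is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in data (assumption): 50 random 1x28x28 "images" with random labels.
full_dataset = TensorDataset(torch.randn(50, 1, 28, 28),
                             torch.randint(0, 10, (50,)))
keep_indices = [0, 5, 7, 12, 30, 41]  # hypothetical indices to keep
subset = Subset(full_dataset, keep_indices)

# Default batch_size is 1: one sample per batch.
default_loader = DataLoader(subset)
# batch_size equal to the subset length: all kept samples in a single batch.
single_batch_loader = DataLoader(subset, batch_size=len(subset))

default_shapes = [tuple(x.shape) for x, _ in default_loader]
single_shapes = [tuple(x.shape) for x, _ in single_batch_loader]

print(len(default_shapes), default_shapes[0])
print(len(single_shapes), single_shapes[0])
```

Using len(subset) avoids depending on whether the indices are stored as a list, a NumPy array, or a tensor.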

I am confused about torch.utils.data.Subset. I have a list of indices and a PyTorch dataset (e.g. CIFAR). When I use the indices to get a subset of the dataset, subset.dataset still has the same length as the original dataset, although the length is correct once the subset is loaded into a DataLoader.

I would like to know how to check the length of a subset and how to iterate over it.

Thank you in advance

The underlying .dataset will not be changed and will keep its original size.
You can check the length of the Subset via len(subset) and iterate it with for data in subset.
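A minimal sketch of the difference, assuming a tiny synthetic TensorDataset in place of CIFAR (the names and indices are made up for illustration):

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Stand-in dataset (assumption): 20 samples whose label equals the sample index.
full_dataset = TensorDataset(torch.arange(20).unsqueeze(1), torch.arange(20))
subset = Subset(full_dataset, [2, 4, 6])

subset_len = len(subset)             # number of selected elements
original_len = len(subset.dataset)   # the wrapped dataset is unchanged
stored_indices = list(subset.indices)

# Iterating the Subset visits only the selected elements, in index order.
iterated_labels = [int(y) for _, y in subset]

print(subset_len, original_len, stored_indices, iterated_labels)
```

So subset.dataset is just a reference to the original dataset, while len(subset) and iteration both go through subset.indices.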


May I ask how you created the Subset object?

Never mind, it’s torch.utils.data.Subset().

Hi @ptrblck, I have a dataset defined, and I want to define a sampler that yields batches of size n, where the n indices in each batch are given by the user. How can I achieve this?

I don’t fully understand your use case. Do you want the user to pass the batch indices in each iteration such that no sampling is used anymore or should these samples be somehow predefined by the user?
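If the second reading applies, i.e. the batches are predefined by the user, one possible approach is the batch_sampler argument of DataLoader, which accepts any iterable that yields lists of indices. The sketch below is only an illustration under that assumption; the dataset and the user_batches list are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset (assumption): 10 samples whose label equals the sample index.
dataset = TensorDataset(torch.arange(10).unsqueeze(1), torch.arange(10))

# Hypothetical user-supplied batches: each inner list holds the indices of one batch.
user_batches = [[0, 3, 5], [1, 2, 9], [4, 8, 7]]

# With batch_sampler set, batch_size, shuffle, sampler and drop_last
# must be left at their defaults.
loader = DataLoader(dataset, batch_sampler=user_batches)

# Each batch contains exactly the samples at the requested indices, in order.
batches = [batch_y.tolist() for _, batch_y in loader]
print(batches)
```

If the indices should instead be supplied on the fly in each iteration, indexing the dataset directly (or building a Subset per request) may be simpler than a custom sampler.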