Error while filtering MNIST dataset

Hello!

I’m trying to filter the MNIST dataset so that it keeps only a desired label, using the following code:

mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
idx = mnist_full.targets == self.labels
mnist_full = mnist_full[idx]

where self.labels = 1.

The expected behaviour would be to obtain a dataset containing only 1s, but instead I get the following error:

  File "/Work/vidanodet/./src/simple_GAN.py", line 55, in setup
    mnist_full = mnist_full[idx]
  File "/usr/local/lib/python3.9/dist-packages/torchvision/datasets/mnist.py", line 127, in __getitem__
    img, target = self.data[index], int(self.targets[index])
ValueError: only one element tensors can be converted to Python scalars

I’ve checked that idx is a 1-D tensor of booleans. Any idea what the problem is?
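The failure can be reproduced without downloading MNIST. This is a minimal sketch that assumes a hypothetical ToyMNIST class (not the torchvision class) whose __getitem__ copies the line shown in the traceback. The labels and image sizes are made up for illustration.

```python
import torch
from torch.utils.data import Dataset


class ToyMNIST(Dataset):
    """Hypothetical stand-in that copies the indexing line from torchvision's MNIST.__getitem__."""

    def __init__(self, data, targets):
        self.data = data        # images, shape (N, 28, 28)
        self.targets = targets  # integer labels, shape (N,)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        # Same pattern as the traceback: int() expects exactly one element
        img, target = self.data[index], int(self.targets[index])
        return img, target


def reproduce_error():
    ds = ToyMNIST(torch.zeros(5, 28, 28), torch.tensor([1, 0, 1, 7, 1]))
    idx = ds.targets == 1  # 1-D boolean mask, as in the question
    try:
        ds[idx]  # the whole mask is handed to __getitem__
    except (ValueError, RuntimeError) as exc:
        return idx, exc
    return idx, None
```

Here ds[idx] passes the entire mask to __getitem__, so self.targets[index] yields three labels at once and int() cannot turn them into a single Python scalar.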

Thanks!

The error occurs because int() is being called on a tensor with many values. MNIST’s __getitem__ is written for a single integer index; when you pass a boolean mask instead, self.targets[index] returns every matching target rather than just one. Python’s built-in int() can convert a tensor with exactly one element, but it does not know what to do with a tensor holding many values.

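# Error: raises ValueError when the tensor has more than one element
int(self.targets[index])

# This should work: .int() casts every element and returns an int32 tensor
self.targets[index].int()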
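A quick check of the difference, using plain tensors with made-up values: Python’s int() needs exactly one element, while the tensor method .int() casts every element to int32 and keeps the result a tensor.

```python
import torch

targets = torch.tensor([1, 0, 1, 7, 1])
mask = targets == 1

selected = targets[mask]  # a boolean mask keeps every match: three elements
single = targets[0]       # an integer index gives a 0-d tensor: one element

as_python_int = int(single)     # fine: exactly one element
as_int_tensor = selected.int()  # fine: elementwise cast, still a tensor

try:
    int(selected)  # three elements cannot become one Python scalar
    failed = False
except (ValueError, RuntimeError):
    failed = True
```

Note that .int() only sidesteps the conversion. Because that line lives inside torchvision, editing it is not a practical fix; the mask has to be applied somewhere other than the dataset’s __getitem__.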

Hello!

You’re right, but that is torchvision code. The real problem was that I was indexing the dataset object itself instead of its internal data and targets tensors. Here is the corrected code:

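mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
# Boolean mask: True where the target equals the desired label
idx = mnist_full.targets == self.labels
# Filter the underlying tensors, not the Dataset object itself
mnist_full.data = mnist_full.data[idx]
mnist_full.targets = mnist_full.targets[idx]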
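The same fix can be wrapped in a reusable function. In this sketch filter_by_label and ToyMNIST are hypothetical names; the only assumption is that the dataset exposes data and targets tensors, as torchvision’s MNIST does. For comparison, subset_by_label shows a different approach (not the one used above): torch.utils.data.Subset wraps the original dataset with the indices of the matching samples and leaves it unmodified.

```python
import torch
from torch.utils.data import Dataset, Subset


class ToyMNIST(Dataset):
    """Hypothetical stand-in exposing .data and .targets like torchvision's MNIST."""

    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        return self.data[index], int(self.targets[index])


def filter_by_label(dataset, label):
    """Keep only samples whose target equals `label` by masking the underlying tensors."""
    mask = dataset.targets == label
    dataset.data = dataset.data[mask]          # reassigns the attribute on this dataset
    dataset.targets = dataset.targets[mask]
    return dataset


def subset_by_label(dataset, label):
    """Alternative: index-based view over the original dataset; nothing is modified."""
    indices = torch.nonzero(dataset.targets == label).flatten().tolist()
    return Subset(dataset, indices)


targets = torch.tensor([1, 0, 1, 7, 1])
# Each 28x28 "image" is filled with its own row number so filtering is easy to check
data = torch.arange(5).view(5, 1, 1).expand(5, 28, 28).clone()

ones_only = filter_by_label(ToyMNIST(data, targets), 1)
ones_view = subset_by_label(ToyMNIST(data, targets), 1)
```

With filter_by_label, len() and per-sample indexing work as usual because __getitem__ still receives integer indices. The trade-off is that it changes the dataset object it is given, while the Subset version keeps the full dataset available for other uses.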