Loading MNIST from PyTorch

Hi Folks,
How can I load the MNIST dataset in PyTorch and filter out the digit 9 or 5 from it?
I am learning PyTorch, so I would appreciate it if you could share the code with me.

torchvision MNIST Dataset

This is based on CIFAR10,

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

but you can change that to:
trainset = torchvision.datasets.MNIST(**KWARGS)

For filtering out a particular category, you may need to define a custom Dataset to feed into the DataLoader:

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

and this is good to read, too:

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
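A minimal sketch of such a custom Dataset, assuming the wrapped dataset exposes a 1-D targets tensor aligned with its indices (torchvision's MNIST does). DigitFilter is a hypothetical name, not part of torchvision:

```python
import torch
from torch.utils.data import Dataset


class DigitFilter(Dataset):
    """Hypothetical wrapper that hides samples whose label is in `excluded`.

    Assumes `base.targets` holds one label per sample, in index order.
    """

    def __init__(self, base, excluded):
        self.base = base
        targets = torch.as_tensor(base.targets)
        keep = torch.ones_like(targets, dtype=torch.bool)
        for digit in excluded:
            keep &= targets != digit
        # Positions in `base` that survive the filter
        self.indices = keep.nonzero(as_tuple=True)[0].tolist()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Map the filtered position back to the original dataset
        return self.base[self.indices[i]]
```

For example, DigitFilter(torchvision.datasets.MNIST("./", train=True, transform=transforms.ToTensor(), download=True), excluded=(5, 9)) should give MNIST without 5s and 9s, and the result can be passed straight to a DataLoader.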

I am following this approach:
import torchvision
from torchvision import transforms

train_data = torchvision.datasets.MNIST("./", train=True, transform=transforms.ToTensor(), download=True)
test_data_xy = torchvision.datasets.MNIST("./", train=False, transform=transforms.ToTensor(), download=True)

idx_train = train_data.train_labels != 5
idx_test = test_data_xy.train_labels != 5

train_data.train_labels = train_data.train_labels[idx_train]
test_data_xy.train_data = test_data_xy.train_data[idx_test]

I'm getting this error:
AttributeError: can't set attribute

I have no idea what needs to be done.

Note that torchvision.datasets.MNIST returns a Dataset object of type 'torchvision.datasets.mnist.MNIST', not a plain list. You can index it to get individual (image, label) pairs, but you typically wrap it in a DataLoader to go through the data points in batches.
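A small sketch of the difference between indexing a Dataset and iterating a DataLoader. To keep it self-contained it uses a TensorDataset of fake 1x28x28 images as a stand-in for MNIST (an assumption; MNIST items are also (image, label) pairs):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 8 fake 1x28x28 "images" with labels 0..7
images = torch.zeros(8, 1, 28, 28)
labels = torch.arange(8)
dataset = TensorDataset(images, labels)

# A Dataset supports len() and indexing; each item is an (image, label) pair
print(len(dataset))
image, label = dataset[3]
print(image.shape, label.item())

# A DataLoader groups items into batches of tensors
loader = DataLoader(dataset, batch_size=4, shuffle=False)
for batch_images, batch_labels in loader:
    print(batch_images.shape, batch_labels.tolist())
```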

This is the code from the tutorial example, changed to load MNIST. Can you try it and see?

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

batch_size = 4
trainset = torchvision.datasets.MNIST(root='./data/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data/', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
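To check which digits a dataset actually contains, one option is to count its labels. A minimal sketch, assuming the labels live in a targets tensor as in torchvision's MNIST; label_counts is a hypothetical helper name:

```python
import torch


def label_counts(targets, num_classes=10):
    """Return how many samples carry each label 0..num_classes-1."""
    return torch.bincount(torch.as_tensor(targets), minlength=num_classes).tolist()
```

Calling label_counts(trainset.targets) on the code above should show a nonzero count for every digit from 0 to 9, including 5, because nothing has been filtered yet.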

Hi Karthik, the code works, but does it actually remove digit 5 from the dataset?

Hi Ronish, no, it uses all the available classes. As J_Johnson said above, you should check out the link on datasets and dataloaders to create a custom dataset that excludes class 5.
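Instead of writing a full Dataset class, one alternative is torch.utils.data.Subset, which exposes only a chosen list of indices of an existing dataset. A minimal sketch; indices_without is a hypothetical helper, and the TensorDataset of fake images stands in for MNIST:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def indices_without(targets, excluded):
    """Return the positions of samples whose label is not in `excluded`."""
    targets = torch.as_tensor(targets)
    keep = torch.ones_like(targets, dtype=torch.bool)
    for label in excluded:
        keep &= targets != label
    return keep.nonzero(as_tuple=True)[0].tolist()


# Stand-in for MNIST: six fake images with these labels
labels = torch.tensor([5, 0, 9, 5, 2, 7])
dataset = TensorDataset(torch.zeros(6, 1, 28, 28), labels)

# Subset exposes only the kept positions, so the DataLoader never sees a 5
kept = Subset(dataset, indices_without(labels, excluded=(5,)))
loader = DataLoader(kept, batch_size=2, shuffle=False)
```

With the real data this would be Subset(trainset, indices_without(trainset.targets, (5,))), and likewise for testset; the original dataset object is left untouched.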

The way you are doing it won't work: train_data and test_data_xy are dataset objects, not plain lists, and the train_labels and train_data attributes you are assigning to are read-only.
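In torchvision's MNIST, train_labels and train_data are deprecated aliases for targets and data, exposed as properties without a setter. targets and data are ordinary attributes, so one way to drop a class in place is to filter both with the same boolean mask. A minimal sketch; drop_digits is a hypothetical helper name:

```python
import torch


def drop_digits(dataset, digits):
    """Remove every sample whose label is in `digits`, in place.

    Assumes `dataset.data` and `dataset.targets` are tensors of equal
    length, which is how torchvision.datasets.MNIST stores them.
    """
    keep = torch.ones_like(dataset.targets, dtype=torch.bool)
    for digit in digits:
        keep &= dataset.targets != digit
    # Filter images and labels with the same mask so they stay aligned
    dataset.data = dataset.data[keep]
    dataset.targets = dataset.targets[keep]
    return dataset
```

With the variables from the question, drop_digits(train_data, (5,)) and drop_digits(test_data_xy, (5,)) would remove the 5s, or pass (5, 9) to remove both digits. MNIST's __getitem__ and __len__ read from data and targets, so a DataLoader built afterwards only sees the remaining digits.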
