Hi Folks,
How can I load the MNIST dataset in PyTorch and filter out the digit 9 or 5 from it?
I am learning PyTorch, so I would appreciate it if you could share the code.
This tutorial is based on CIFAR10:
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
but you can change that to:
trainset = torchvision.datasets.MNIST(**KWARGS)
To filter out a particular class, you may need to define a custom Dataset and load it with a DataLoader:
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
and this is good to read, too:
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
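A minimal sketch of such a custom Dataset, assuming the wrapped dataset exposes its labels as a targets tensor the way torchvision.datasets.MNIST does. The class name ExcludeDigits and its excluded parameter are hypothetical, made up for this example, and the small TensorDataset at the end is only a stand-in so the snippet runs without downloading MNIST.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class ExcludeDigits(Dataset):
    """Hypothetical wrapper that hides every sample whose label is in `excluded`.

    Assumes `base.targets` holds integer labels, as in torchvision.datasets.MNIST.
    """

    def __init__(self, base, excluded=(5,)):
        self.base = base
        targets = torch.as_tensor(base.targets)
        # True for samples to keep: the label is not one of the excluded digits
        keep = ~torch.isin(targets, torch.as_tensor(list(excluded)))
        # Remember the surviving positions in the original dataset
        self.indices = torch.nonzero(keep).flatten().tolist()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Translate the filtered index back to an index into the original dataset
        return self.base[self.indices[i]]


# Offline stand-in for MNIST: five blank samples labelled 0, 5, 9, 5, 3
demo_labels = torch.tensor([0, 5, 9, 5, 3])
demo = TensorDataset(torch.zeros(5, 1, 28, 28), demo_labels)
demo.targets = demo_labels  # mimic the MNIST attribute the wrapper reads
filtered = ExcludeDigits(demo, excluded={5, 9})
```

With real data you would pass the MNIST object instead, for example ExcludeDigits(trainset, excluded={5, 9}), and hand the result to torch.utils.data.DataLoader as in the tutorial. Precomputing the index list once keeps __getitem__ a simple lookup.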
I am following the approach below:
import torchvision
import torchvision.transforms as transforms

train_data = torchvision.datasets.MNIST("./", train=True, transform=transforms.ToTensor(), download=True)
test_data_xy = torchvision.datasets.MNIST("./", train=False, transform=transforms.ToTensor(), download=True)
# boolean masks that are True for every sample whose label is not 5
idx_train = train_data.train_labels != 5
idx_test = test_data_xy.train_labels != 5
train_data.train_labels = train_data.train_labels[idx_train]
test_data_xy.train_data = test_data_xy.train_data[idx_test]
I am getting this error:
AttributeError: can't set attribute
I have no clue what needs to be done.
Note that torchvision.datasets.MNIST returns an object of type 'torchvision.datasets.mnist.MNIST', not a plain list. You can index it to get a single (image, label) pair, but you normally wrap it in a DataLoader to iterate over the data points in batches.
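A small sketch of the two access patterns. It uses a TensorDataset of blank 1x28x28 images as a stand-in for MNIST (an assumption so it runs offline); with torchvision.datasets.MNIST and transforms.ToTensor() the image shapes would be the same.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: ten blank 1x28x28 images labelled 0..9
dataset = TensorDataset(torch.zeros(10, 1, 28, 28), torch.arange(10))

# Indexing a map-style dataset returns a single (image, label) pair
image, label = dataset[3]
print(image.shape, label.item())  # torch.Size([1, 28, 28]) 3

# A DataLoader draws samples and stacks them into batches along a new first dimension
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batches = [(images.shape, labels.tolist()) for images, labels in loader]
for shape, labels in batches:
    print(shape, labels)
```

With batch_size=4 and ten samples, the last batch holds only two images because drop_last defaults to False.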
This is the code from the tutorial example, adapted to load MNIST. Can you try it and see?
import torch
import torchvision
import torchvision.transforms as transforms

# MNIST images have a single channel, so Normalize takes one mean and one std
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
batch_size = 4
# download=True fetches the files into ./data/ if they are not there yet
trainset = torchvision.datasets.MNIST(root='./data/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data/', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
Hi Karthik, the code works, but I am not sure whether it removes the digit 5 from the dataset.
Hi Ronish, no, I used all the available classes. As J_Johnson said above, you should check out the link on datasets and dataloaders to create a custom dataset that excludes class 5.
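If you would rather not write a new class, torch.utils.data.Subset can wrap the existing dataset together with just the indices you want to keep. A minimal sketch follows; the helper name without_digits is hypothetical, and for MNIST the labels argument would be trainset.targets. The TensorDataset at the end is only a stand-in so it runs offline.

```python
import torch
from torch.utils.data import Subset, TensorDataset


def without_digits(dataset, labels, excluded=(5,)):
    """Hypothetical helper: a Subset of `dataset` minus samples whose label is in `excluded`.

    `labels[i]` must be the label of `dataset[i]`; the original dataset is not modified.
    """
    labels = torch.as_tensor(labels)
    # True where the label is not one of the excluded digits
    keep = ~torch.isin(labels, torch.as_tensor(list(excluded)))
    return Subset(dataset, torch.nonzero(keep).flatten().tolist())


# Offline stand-in: six samples labelled 0, 5, 9, 5, 1, 9
demo = TensorDataset(torch.arange(6.0), torch.tensor([0, 5, 9, 5, 1, 9]))
demo_subset = without_digits(demo, demo.tensors[1], excluded={5, 9})
```

With the tutorial code you could write trainset = without_digits(trainset, trainset.targets) before building trainloader, and do the same for testset; the DataLoader calls stay unchanged.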
The way you are doing it won't work, since train_data and test_data_xy are not plain lists you can modify directly.
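The specific AttributeError comes from how torchvision defines the MNIST class: train_labels and train_data are deprecated, read-only properties that simply return targets and data, so assigning to them fails. The underlying data and targets attributes are ordinary tensors and can be reassigned. A minimal sketch of filtering in place, applying one mask to both so images and labels stay aligned; drop_digits_in_place is a hypothetical name, and the SimpleNamespace at the end is only an offline stand-in with the same two attributes.

```python
from types import SimpleNamespace

import torch


def drop_digits_in_place(mnist, excluded=(5,)):
    """Hypothetical helper: remove samples whose label is in `excluded`.

    Writes to the plain `data` and `targets` attributes (not the read-only
    `train_data` / `train_labels` properties) and filters both with one mask.
    """
    keep = torch.ones_like(mnist.targets, dtype=torch.bool)
    for digit in excluded:
        keep &= mnist.targets != digit  # drop every sample with this label
    mnist.data = mnist.data[keep]
    mnist.targets = mnist.targets[keep]
    return mnist


# Offline stand-in with MNIST's two attributes: five blank 28x28 uint8 images
fake = SimpleNamespace(
    data=torch.zeros(5, 28, 28, dtype=torch.uint8),
    targets=torch.tensor([0, 5, 9, 5, 3]),
)
drop_digits_in_place(fake, excluded={5, 9})
```

On the real datasets this would be drop_digits_in_place(train_data) and drop_digits_in_place(test_data_xy). Because it mutates the object, reload the dataset if you later need all ten digits.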