[resolved] Is there MNIST dataset in torchvision.datasets?

(Yunjey) #1

I want to create a PyTorch tutorial using MNIST data set.
In TensorFlow, there is a simple way to download, extract and load the MNIST data set as below.

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./mnist/data/")
x_train = mnist.train.images    # numpy array 
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

Is there any simple way to handle this in PyTorch?

(Yunjey) #2

It seems that torchvision.datasets does support the MNIST data set.
I was confused because the PyTorch documentation does not mention MNIST.

(Ajay Talati) #3

Yes, it’s already there - see here


and here,


The code looks something like this,

import torch
from torchvision import datasets, transforms
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
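
For reference, here is a minimal sketch of what the Normalize((0.1307,), (0.3081,)) step computes, assuming the input is already a float tensor in [0, 1] such as transforms.ToTensor() produces. It uses plain torch and a random stand-in image (an assumption, not real MNIST data) so that it runs without downloading anything.

```python
import torch

mean, std = 0.1307, 0.3081

# Stand-in for one MNIST image after ToTensor(): shape (1, 28, 28), values in [0, 1].
image = torch.rand(1, 28, 28)

# Normalize subtracts the per-channel mean and divides by the per-channel std.
normalized = (image - mean) / std

# Undoing the transform recovers the original pixel values.
restored = normalized * std + mean
```

Normalize expects a tensor rather than a PIL image, so ToTensor() has to come before it in the Compose list.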

(Ajay Talati) #4

How do you subset the MNIST training data? It has 60,000 images; how can you reduce it to, say, 2,000?

Here’s the code

>>> from torchvision import datasets, transforms

>>> train_all_mnist = datasets.MNIST('../data', train=True, download=True,
...                    transform=transforms.Compose([
...                        transforms.ToTensor(),
...                        transforms.Normalize((0.1307,), (0.3081,))
...                    ]))
Files already downloaded
>>> train_all_mnist
<torchvision.datasets.mnist.MNIST object at 0x7f89a150cfd0>

How do I subset train_all_mnist?

Or, alternatively I could just download it again, and hack this line to 2000,


It’s a bit ugly - anyone know a neater way to do this?
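
One possible way to do this, sketched under the assumption that your PyTorch version provides torch.utils.data.Subset, is to wrap the dataset together with a list of indices. FakeMNIST below is a hypothetical stand-in with the same length and item shape as MNIST, so the sketch runs without a download; with the real data you would pass train_all_mnist instead.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class FakeMNIST(Dataset):
    """Hypothetical stand-in: 60,000 blank 1x28x28 images with labels 0-9."""

    def __len__(self):
        return 60000

    def __getitem__(self, idx):
        return torch.zeros(1, 28, 28), idx % 10


full_train = FakeMNIST()  # with real data: full_train = train_all_mnist

# Draw 2000 distinct indices at random; the seed makes the choice repeatable.
generator = torch.Generator().manual_seed(0)
indices = torch.randperm(len(full_train), generator=generator)[:2000].tolist()

# Subset exposes only the chosen items and leaves the full dataset untouched.
train_subset = Subset(full_train, indices)
train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
```

Passing list(range(2000)) as the indices would take the first 2000 images instead of a random sample.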

(Yunjey) #5

What is your purpose in subsetting the training dataset?

(Ajay Talati) #6

I’m interested in Omniglot, which is like an inverse MNIST: lots of classes, each with a small number of examples.

Take a look, here

By the way - thank you for your tutorials - they are very clear and helpful to learn from.

Best regards,



Omniglot is in this pull request:

(Ajay Talati) #8

Ha, fanck Q :smile:

Spent an hour hacking together my own loader - but this looks better!

Seems to be the easiest data set for experimenting with one-shot learning?

(Ajay Talati) #9

What’s the current best methodology for Omniglot? Who or what is doing the best at the moment?


@pranv set the record on Omniglot recently with his paper:

Attentive Recurrent Comparators

(Ajay Talati) #11

Thanks for that :wink:

It looks like the DRAW I implemented in Torch years ago :smile:, without the VAE and the decoder/generative canvas.

I thought you might like this implementation of a GAN on Omniglot,

Code for training a GAN on the Omniglot dataset using the network described in:
Task Specific Adversarial Cost Function

(Ritchie Ng) #12

Have you found a better way to do this?

(Ajay Talati) #13

Nope sorry - been totally snowed under the past couple of months - not had any time to work on it.

If you’re referring to the alternative cost functions for GANs I don’t think they make much difference?

If you’re referring to non-Gaussian attention mechanisms for the DRAW encoder, I don’t know of any better approach than @pranav’s, as mentioned above. I think he’s open sourced his code?



(Pranav Shyam) #14

The code for Attentive Recurrent Comparators is here: https://github.com/pranv/ARC

It includes Omniglot data downloading and iterating scripts along with all the models proposed in the paper (the nets are written and trained with Theano).

I will try to submit a PR for torchvision.datasets.Omniglot if I find some time :slight_smile: