[resolved] Is there an MNIST dataset in torchvision.datasets?


(Yunjey) #1

I want to create a PyTorch tutorial using the MNIST dataset.
In TensorFlow, there is a simple way to download, extract, and load the MNIST dataset, as shown below:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./mnist/data/")
x_train = mnist.train.images    # numpy array 
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

Is there any simple way to handle this in PyTorch?


(Yunjey) #2

It seems torchvision.datasets does support the MNIST dataset.
I was confused because the PyTorch documentation does not mention MNIST.


(Ajay Talati) #3

Yes, it’s already there. See here:

http://pytorch.org/docs/data.html

and here:

http://pytorch.org/docs/torchvision/datasets.html#mnist

The code looks something like this:

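import torch
from torchvision import datasets, transforms

batch_size = 64  # stand-in for args.batch_size from the original example script
# Extra loader options used when CUDA is available: one worker process, pinned memory
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

# Convert images to [0, 1] tensors, then normalize with the MNIST mean and std
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False, **kwargs)  # no need to shuffle for evaluation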
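The Normalize((0.1307,), (0.3081,)) step is plain per-pixel arithmetic: ToTensor() scales 8-bit pixels into [0, 1], and Normalize subtracts the mean and divides by the standard deviation. Here is a small pure-Python sketch of what happens to a single pixel. The helper name normalize_pixel is hypothetical, and it mimics the two transforms rather than being torchvision's own code. It also assumes 0.1307 and 0.3081 are the mean and std of the MNIST training pixels, as they are commonly described.

```python
# Hypothetical helper mimicking, for one pixel, what
# transforms.ToTensor() followed by transforms.Normalize((0.1307,), (0.3081,)) computes.
MNIST_MEAN = 0.1307  # assumed mean of MNIST training pixels after scaling to [0, 1]
MNIST_STD = 0.3081   # assumed standard deviation of the same pixels


def normalize_pixel(raw: int, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Map an 8-bit MNIST pixel value to its normalized value."""
    if not 0 <= raw <= 255:
        raise ValueError("MNIST pixels are 8-bit values in [0, 255]")
    scaled = raw / 255.0          # ToTensor(): uint8 [0, 255] -> float [0.0, 1.0]
    return (scaled - mean) / std  # Normalize(): (x - mean) / std


print(round(normalize_pixel(0), 4))    # black background pixel -> -0.4242
print(round(normalize_pixel(255), 4))  # full-intensity stroke  -> 2.8215
```

After this step the network sees inputs that are roughly zero-mean with unit variance over the training set, instead of values in [0, 1].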

(Ajay Talati) #4

How do you subset the MNIST training data? It has 60,000 images; how can you reduce it to, say, 2000?

Here’s the code:

>>> from torchvision import datasets, transforms
>>> train_all_mnist = datasets.MNIST('../data', train=True, download=True,
...                    transform=transforms.Compose([
...                        transforms.ToTensor(),
...                        transforms.Normalize((0.1307,), (0.3081,))
...                    ]))
Files already downloaded
>>> train_all_mnist
<torchvision.datasets.mnist.MNIST object at 0x7f89a150cfd0>

How do I subset train_all_mnist?

Alternatively, I could just download it again and hack this line to use 2000:

https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L64

It’s a bit ugly. Does anyone know a neater way to do this?
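One way to avoid editing mnist.py is to wrap the full dataset in a view that only exposes a chosen list of indices. The sketch below is a plain-Python illustration of that idea. The class IndexSubset and the helper random_subset are hypothetical names, and a plain list stands in for train_all_mnist so the example runs without downloading anything. In more recent PyTorch releases, torch.utils.data.Subset(dataset, indices) provides this kind of wrapper directly.

```python
import random
from typing import Any, Sequence


class IndexSubset:
    """Hypothetical view of a map-style dataset restricted to the given indices.

    Works with anything that supports len() and integer indexing,
    which is the same idea behind torch.utils.data.Subset.
    """

    def __init__(self, dataset: Sequence[Any], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> Any:
        # Translate position i in the subset into a position in the full dataset
        return self.dataset[self.indices[i]]


def random_subset(dataset: Sequence[Any], n_keep: int, seed: int = 0) -> IndexSubset:
    """Keep n_keep examples chosen uniformly at random, without replacement."""
    rng = random.Random(seed)  # fixed seed so the same subset is chosen every run
    indices = rng.sample(range(len(dataset)), n_keep)
    return IndexSubset(dataset, indices)


# Stand-in for train_all_mnist: 60,000 (image, label) pairs
full = [(f"image_{i}", i % 10) for i in range(60000)]
small = random_subset(full, 2000)
print(len(small))  # 2000
```

With PyTorch itself, the equivalent would be building the index list the same way and calling torch.utils.data.Subset(train_all_mnist, indices), which can then be handed to a DataLoader as usual. Nothing is re-downloaded or copied; only the index list is stored.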


(Yunjey) #5

What is your purpose in subsetting the training dataset?


(Ajay Talati) #6

I’m interested in Omniglot, which is like an inverse MNIST: lots of classes, each with a small number of examples.

Take a look here.

By the way, thank you for your tutorials; they are very clear and helpful to learn from.

Best regards,

Ajay
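To mimic that Omniglot-like regime with MNIST (few examples per class), one option is to subsample per class instead of uniformly, so every digit keeps the same small number of examples. A minimal sketch, assuming the training labels are available as a plain list of ints; the helper per_class_indices is hypothetical, and synthetic labels stand in for the real MNIST ones. The returned indices could be passed to a subset wrapper such as torch.utils.data.Subset.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def per_class_indices(labels: Sequence[int], per_class: int, seed: int = 0) -> List[int]:
    """Return indices keeping exactly `per_class` randomly chosen examples of every label."""
    # Group dataset positions by their class label
    by_label: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    rng = random.Random(seed)  # fixed seed for a reproducible subset
    chosen: List[int] = []
    for label in sorted(by_label):
        pool = by_label[label]
        if len(pool) < per_class:
            raise ValueError(f"label {label} has only {len(pool)} examples")
        chosen.extend(rng.sample(pool, per_class))
    return chosen


# Synthetic stand-in for the 60,000 MNIST training labels (digits 0-9)
labels = [i % 10 for i in range(60000)]
keep = per_class_indices(labels, per_class=200)  # 10 classes x 200 = 2000 examples
print(len(keep))  # 2000
```

Compared with a uniform random subset, this guarantees balanced classes, which matters more as per_class gets small.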


#7

Omniglot is in this pull request:


(Ajay Talati) #8

Ha, thank you :smile:

I spent an hour hacking together my own loader, but this looks better!

It seems to be the easiest dataset for experimenting with one-shot learning?
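A common way to set up one-shot experiments on many-class data like Omniglot is to sample small episodes: pick N classes, keep K labelled examples of each as a support set, and hold out further examples of the same classes as queries. The pure-Python sketch below samples one such episode from (item, label) pairs. The function name and its n_way/k_shot/n_query parameters are illustrative assumptions, not part of torchvision or of the loader in that pull request. Synthetic data stands in for Omniglot (which has 20 drawings per character).

```python
import random
from collections import defaultdict
from typing import Any, Dict, List, Sequence, Tuple

Example = Tuple[Any, int]  # (item, class label)


def sample_episode(data: Sequence[Example], n_way: int, k_shot: int = 1,
                   n_query: int = 1, seed: int = 0) -> Tuple[List[Example], List[Example]]:
    """Sample one N-way K-shot episode: a support set and a disjoint query set."""
    by_label: Dict[int, List[Any]] = defaultdict(list)
    for item, label in data:
        by_label[label].append(item)

    rng = random.Random(seed)
    # Only classes with enough examples for both support and query are eligible
    eligible = sorted(label for label, items in by_label.items()
                      if len(items) >= k_shot + n_query)
    classes = rng.sample(eligible, n_way)

    support: List[Example] = []
    query: List[Example] = []
    for label in classes:
        # Draw distinct items so support and query never share an example
        picks = rng.sample(by_label[label], k_shot + n_query)
        support += [(item, label) for item in picks[:k_shot]]
        query += [(item, label) for item in picks[k_shot:]]
    return support, query


# Synthetic stand-in: 100 character classes, 20 drawings each
data = [(f"char{c}_draw{d}", c) for c in range(100) for d in range(20)]
support, query = sample_episode(data, n_way=5)
print(len(support), len(query))  # 5 5
```

A one-shot model is then scored on how well it labels each query using only the single support example per class; repeating this over many random episodes gives the usual N-way one-shot accuracy.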


(Ajay Talati) #9

What’s the current best approach for Omniglot? Who or what is doing best at the moment?


#10

@pranv set the record on Omniglot recently with his paper:

Attentive Recurrent Comparators
https://arxiv.org/abs/1703.00767


(Ajay Talati) #11

Thanks for that :wink:

It looks like the DRAW model I implemented in Torch years ago :smile:, without the VAE and the decoder/generative canvas.

I thought you might like this implementation of a GAN on Omniglot:

Code for training a GAN on the Omniglot dataset using the network described in:
Task Specific Adversarial Cost Function


(Ritchie Ng) #12

Have you found a better way to do this?


(Ajay Talati) #13

Nope, sorry. I’ve been totally snowed under for the past couple of months and haven’t had any time to work on it.

If you’re referring to the alternative cost functions for GANs, I don’t think they make much difference.

If you’re referring to non-Gaussian attention mechanisms for the DRAW encoder, I don’t know of any better approach than @pranv’s, mentioned above. I think he’s open-sourced his code?

Cheers,

Aj


(Pranav Shyam) #14

The code for Attentive Recurrent Comparators is here: https://github.com/pranv/ARC

It includes Omniglot data downloading and iterating scripts along with all the models proposed in the paper (the nets are written and trained with Theano).

I will try to submit a PR for torchvision.datasets.Omniglot if I find some time :slight_smile: