Show an image with a specific index from the MNIST dataset

I am trying to show an image from the MNIST dataset. In Keras, I found it can be done like this. How can I do it in PyTorch?

import matplotlib.pyplot as plt
from keras.datasets import mnist

# Only use the next line if using IPython/Jupyter
%matplotlib inline

(x_train, y_train), (x_test, y_test) = mnist.load_data()
image_index = 7777  # Any index from 0 to 59,999
print(y_train[image_index])  # The label is 8
plt.imshow(x_train[image_index], cmap='Greys')

In case you didn’t use any transformations, you’ll get PIL.Images and can show them directly.
If you transformed the images into tensors, you could also use matplotlib to display the images.
Here are the two ways to show the images:

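import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# As PIL.Image (no transform)
dataset = datasets.MNIST(root='PATH', download=True)
x, _ = dataset[7777]
x.show()  # x is a PIL.Image here

# As torch.Tensor
dataset = datasets.MNIST(
    root='PATH',
    download=True,
    transform=transforms.ToTensor()
)

x, _ = dataset[7777]  # x is now a torch.Tensor of shape [1, 28, 28]
plt.imshow(x.numpy()[0], cmap='gray')
plt.show()  # not needed with %matplotlib inline in a notebook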
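To mirror the Keras snippet a bit more closely (print the label, then plot the image), a small helper could look like the sketch below. It assumes a torchvision-style dataset whose items are (image, label) tuples, where the image is either a PIL image (no transform) or a [1, H, W] tensor (ToTensor()). show_sample is a hypothetical name, not a torchvision function.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch


def show_sample(dataset, index, cmap="Greys"):
    """Print the label of dataset[index] and plot its image (hypothetical helper).

    Assumes dataset[index] returns an (image, label) tuple, where image is
    either a grayscale PIL image or a [1, H, W] tensor as produced by ToTensor().
    """
    img, label = dataset[index]
    print(label)
    if isinstance(img, torch.Tensor):
        array = img[0].numpy()  # [1, H, W] tensor -> [H, W] array
    else:
        array = np.asarray(img)  # grayscale PIL image -> [H, W] uint8 array
    plt.imshow(array, cmap=cmap)
    return label, array


# Usage with the real dataset (downloads MNIST on the first run):
# from torchvision import datasets, transforms
# dataset = datasets.MNIST(root='PATH', download=True, transform=transforms.ToTensor())
# show_sample(dataset, 7777)
# plt.show()
```

Returning the label and the array is just a convenience so the result can be inspected or reused; drop the return if you only want the plot.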

Hi, first time writing in this forum. @ptrblck, I love your solutions; you've always got great ones, this one included. Here's a question: who or what on earth does .numpy() belong to? I can't find any documentation on it, and it works without numpy being imported. Most of all, what is it doing so that I can show my image with imshow? And what is the 0th element?

Second part to this question: if I leave out cmap='gray', I get this green and purple thing. What am I seeing there?

Thank you


I’m glad you like it here. :wink:

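You can find the documentation for tensor.numpy() in the torch.Tensor section of the PyTorch docs.
It returns a numpy array that shares the same storage as your tensor, so it's really cheap and doesn't copy any data. Note that this only works for CPU tensors: a CUDA tensor has to be moved to the CPU first with tensor.cpu().numpy(), which does copy, and a tensor that requires gradients has to be detached first with tensor.detach().numpy().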
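A quick way to check the shared storage, and the two cases where numpy() can't be called directly, is a sketch like the following. The tensors are made-up examples rather than MNIST data, and numpy is deliberately not imported to show that the method comes from torch.Tensor itself.

```python
import torch

t = torch.zeros(3)
a = t.numpy()      # no copy: a and t share the same memory
a[0] = 5.0         # writing through the numpy array...
print(t)           # ...changes the tensor too: tensor([5., 0., 0.])
print(type(a))     # <class 'numpy.ndarray'>

w = torch.ones(2, requires_grad=True)
try:
    w.numpy()      # not allowed while the tensor tracks gradients
except RuntimeError as err:
    print("RuntimeError:", err)
w_np = w.detach().numpy()  # detach first (this still doesn't copy)

if torch.cuda.is_available():
    g = torch.ones(2, device="cuda")
    g_np = g.cpu().numpy()  # .cpu() copies the data to host memory first
```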

numpy() is a method of torch.Tensor, and PyTorch uses NumPy internally to provide it, so you don't need to import numpy yourself.

The [0] index selects the "0th channel", since the transformed image x has the shape [channels, height, width]. MNIST images have a single channel ([1, 28, 28]), so indexing with 0 gives a [28, 28] array. Without it, plt.imshow would raise a TypeError, because it expects an array of shape [height, width], [height, width, 3] or [height, width, 4].
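The shape handling can be sketched with a random tensor standing in for the MNIST sample (an assumption, so nothing needs to be downloaded). Indexing with [0] and calling squeeze(0) both drop the channel dimension, while passing the [1, 28, 28] array directly is rejected by imshow.

```python
import matplotlib.pyplot as plt
import torch

x = torch.rand(1, 28, 28)        # stand-in for a ToTensor() MNIST image: [C, H, W]

img_a = x.numpy()[0]             # index the single channel -> shape (28, 28)
img_b = x.squeeze(0).numpy()     # same result by removing the size-1 dimension
print(img_a.shape, img_b.shape)  # (28, 28) (28, 28)

try:
    plt.imshow(x.numpy())        # shape (1, 28, 28) is not a valid image shape
except TypeError as err:
    print("TypeError:", err)

plt.imshow(img_a, cmap="gray")   # a 2D array works
```

squeeze(0) only removes dimension 0 if it has size 1, so it is a slightly more defensive alternative to [0] when the input might already be 2D.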

matplotlib applies a default colormap to single-channel images, which is 'viridis' (that's the green and purple image you are seeing).
This is supposed to give you better visual information about your matrix, since small changes in value are easier to see than with a grayscale colormap.
However, since we are dealing with grayscale images, we can just tell matplotlib to use a grayscale colormap via cmap='gray'. :slight_smile:
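To see the difference yourself, you can plot the same array with the default colormap and with cmap='gray' side by side. This sketch uses a random 28x28 array as a stand-in for an MNIST image; the default colormap is read from matplotlib's image.cmap setting, which is 'viridis' unless you changed it.

```python
import matplotlib.pyplot as plt
import numpy as np

image = np.random.rand(28, 28)  # stand-in for a single-channel MNIST image

fig, (ax_default, ax_gray) = plt.subplots(1, 2, figsize=(6, 3))
im_default = ax_default.imshow(image)          # uses rcParams["image.cmap"]
im_gray = ax_gray.imshow(image, cmap="gray")   # grayscale colormap
ax_default.set_title(im_default.cmap.name)     # 'viridis' unless changed
ax_gray.set_title(im_gray.cmap.name)           # 'gray'
plt.show()
```

The underlying values are identical in both panels; only the mapping from value to color differs.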


Ah OK, so it does come from PyTorch. Great, I was wondering how it was working.

Also thank you for that bit of info about the colormap. I actually like the way it looks.

Thanks for always having a good answer!


Hi @ptrblck. What is the purpose of the underscore "_" in your code, if you don't mind me asking?

_ is a Python convention for a variable whose value won't be used (a throwaway variable). In this case, dataset[7777] returns a tuple (image, label): the image is assigned to x, and the label is assigned to _ since it isn't needed.
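The unpacking works the same on any (image, label) pair. In this sketch a plain list of tuples is a hypothetical stand-in for the dataset (not torchvision code), so it runs without downloading MNIST.

```python
# Hypothetical stand-in: each item is an (image, label) tuple, like dataset[i]
fake_dataset = [("image_0", 7), ("image_1", 2), ("image_2", 1)]

x, _ = fake_dataset[1]       # keep the image, throw away the label
print(x)                     # image_1

x, label = fake_dataset[1]   # or keep both if you need the label
print(x, label)              # image_1 2

images = [img for img, _ in fake_dataset]  # _ also works inside loops
print(images)                # ['image_0', 'image_1', 'image_2']
```

Python itself treats _ as an ordinary variable name; the convention only signals to readers that the value is intentionally ignored.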
