How to show an image in Jupyter Notebook with PyTorch easily?

In iTorch, I can use itorch.image to show an image in the notebook.


Yeah, this frustrated me a lot too, because it's so easy with iTorch. You have a few options in Python, but there's no stand-alone command that I'm aware of for displaying an image from a PyTorch tensor. If you convert to a PIL image, you can put the image variable as the last expression in a cell and Jupyter will display it. To load an image file into PIL:

from PIL import Image
img = Image.open('path-to-image-file').convert('RGB')

Or to convert straight from a PyTorch tensor:

import torchvision
to_pil = torchvision.transforms.ToPILImage()
img = to_pil(your_tensor)  # your_tensor: a (C, H, W) tensor; float values are expected in [0, 1]
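If you'd rather not go through torchvision, a CPU tensor can also be turned into a PIL image with NumPy and Pillow's `Image.fromarray`. A minimal sketch, assuming a (C, H, W) image with C = 1 or 3; the helper name `chw_to_pil` is hypothetical, and the min-max rescaling is just one choice (clipping to [0, 1] is another). A tensor that requires grad or lives on the GPU needs `.detach().cpu()` first.

```python
import numpy as np
from PIL import Image

def chw_to_pil(img):
    """Convert a (C, H, W) image (NumPy array or CPU tensor) to a PIL Image.

    Values are min-max rescaled to 0..255, so any numeric range is accepted.
    """
    arr = np.asarray(img, dtype=np.float32)  # CPU tensors convert via __array__
    arr = np.transpose(arr, (1, 2, 0))        # (C, H, W) -> (H, W, C)
    lo, hi = arr.min(), arr.max()
    scale = (hi - lo) if hi > lo else 1.0     # avoid dividing by zero on a constant image
    arr = ((arr - lo) / scale * 255).round().astype(np.uint8)
    if arr.shape[2] == 1:
        arr = arr[:, :, 0]                    # Pillow wants (H, W) for grayscale
    return Image.fromarray(arr)

# Example: random noise shaped like a 3 x 64 x 64 image tensor.
pil_img = chw_to_pil(np.random.randn(3, 64, 64))
pil_img  # as the last expression in a Jupyter cell, this renders the image
```

Rescaling by the image's own min and max means a noise tensor from `normal_()` still displays sensibly, at the cost of hiding the absolute intensity.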

Other than that, there's the usual route of converting to a NumPy array and using matplotlib, and if you've used OpenCV (cv2), there are cv2 equivalents. In Python it's largely a matter of whatever you're comfortable with, although OpenCV may be faster.


This is how I display an image:

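import torch as t
from torchvision.transforms import ToPILImage
from IPython.display import Image, display
to_img = ToPILImage()

# display tensor (ToPILImage expects float values in [0, 1])
a = t.rand(3, 64, 64)
display(to_img(a))

# display imagefile ('path-to-image-file.png' is a placeholder)
display(Image(filename='path-to-image-file.png'))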

You can look at my notebook here, but yeah, there seems to be no standard solution.

Basically, I do:

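%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

def show(img):
    # img: a (C, H, W) CPU tensor; matplotlib expects (H, W, C)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')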
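The `show` helper above assumes a 3-channel (C, H, W) tensor. A slightly more defensive sketch that also handles single-channel and already-2-D images, assuming channels-first input that is a NumPy array or a CPU tensor without grad (call `.detach().cpu()` first otherwise); `to_imshow_array` and `show_any` are hypothetical names, not PyTorch functions:

```python
import numpy as np

def to_imshow_array(img):
    """Rearrange a channels-first image into a layout plt.imshow accepts.

    (C, H, W) with C in {1, 3, 4} becomes (H, W, C); a single channel is then
    squeezed to (H, W). 2-D inputs pass through unchanged.
    """
    arr = np.asarray(img)  # works for NumPy arrays and CPU tensors without grad
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))  # channels-first -> channels-last
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]  # (H, W, 1) -> (H, W) so a colormap is applied
    return arr

def show_any(img):
    import matplotlib.pyplot as plt  # local import keeps to_imshow_array matplotlib-free
    arr = to_imshow_array(img)
    plt.imshow(arr, cmap='gray' if arr.ndim == 2 else None, interpolation='nearest')
    plt.axis('off')
```

Because it decides by the size of the first axis, an image that is already (H, W, C) with H equal to 1, 3 or 4 would be misread; pass such arrays to `plt.imshow` directly.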

Hi Soumith,

I would like to know how I can use your solution if I am using an iterator like this:

dataiter = iter(trainloader)
images, labels = next(dataiter)

and the trainloader loads data using torchvision.datasets.ImageFolder.

I am stuck on this. I have created my own dataset and I am following the CIFAR10 example from the PyTorch tutorial.
I would like to train an AlexNet using my own dataset.

Thank you in advance.


I have the same question.

This doesn't seem to work inside a loop; I have to wrap it in display().
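Jupyter only renders the value of the last expression in a cell, and expressions inside a loop body are never rendered, so each image has to be passed to `IPython.display.display` explicitly. A minimal sketch, assuming uint8 images in (C, H, W) layout; the random batch and the helper name `display_each` are hypothetical:

```python
import numpy as np
from PIL import Image
from IPython.display import display

def display_each(batch):
    """Display every (C, H, W) uint8 image in `batch`, one after another."""
    shown = []
    for chw in batch:
        hwc = np.ascontiguousarray(np.transpose(chw, (1, 2, 0)))  # Pillow expects (H, W, C)
        pil = Image.fromarray(hwc)
        display(pil)  # a bare `pil` here would not be rendered inside a loop
        shown.append(pil)
    return shown

# Hypothetical batch: four random 3 x 32 x 32 uint8 images.
batch = np.random.randint(0, 256, size=(4, 3, 32, 32), dtype=np.uint8)
images_shown = display_each(batch)
```

Outside a notebook, `display` falls back to printing each object's repr, so the function can still be run from a plain script.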


This is what I did; I hope it can be helpful for someone.

from PIL import Image
image = Image.open('img_path.jpg').convert('RGB')
image  # as the last line of a cell, Jupyter displays the PIL image

What I did: convert the tensor to a NumPy array, reshape it if needed, and call plt.imshow():

import matplotlib.pyplot as plt

img_np_arr = img_tensor.numpy()   # convert the single-channel CPU tensor (img_tensor) to a NumPy array
print(img_np_arr.shape)           # check the shape first, e.g. (1, img_h, img_w)
img_np_arr_reshaped = img_np_arr.reshape(img_h, img_w)  # reshape to a 2-D (height, width) image
plt.imshow(img_np_arr_reshaped, cmap='gray')            # display as a grayscale image
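A note on the reshape: NumPy indexes 2-D images as (rows, columns), i.e. (height, width), so the target shape should be (img_h, img_w). For a non-square image, swapping the two still runs without error but scrambles the pixels, and dropping the channel axis with `squeeze` avoids passing the sizes at all. A small check on a made-up 1 x 2 x 3 "image" (hypothetical data):

```python
import numpy as np

# Hypothetical single-channel image with height 2 and width 3, laid out like a (1, H, W) tensor.
img = np.arange(6).reshape(1, 2, 3)

by_squeeze = img.squeeze(0)  # (2, 3): just drops the channel axis
by_h_w = img.reshape(2, 3)   # (H, W): same pixels in the same layout
by_w_h = img.reshape(3, 2)   # (W, H): same numbers, but the rows are wrong
```

`by_squeeze` and `by_h_w` hold identical rows `[[0, 1, 2], [3, 4, 5]]`, while `by_w_h` regroups the same values as `[[0, 1], [2, 3], [4, 5]]`, which is not even the transpose of the real image.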

This worked flawlessly for displaying an image from the CIFAR-100 dataset. Thank you!