How to use matplotlib to display images?

I want to show some pictures from my dataset, so I ran this code in a Jupyter notebook:


The output is:

Two questions:
(1) The original pictures are RGB, but the output looks like grayscale images. Why?
(2) Nothing in my code repeats the pictures, so why does every picture appear nine times, seemingly three times per channel?

Thank you~

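Your image tensor has shape [batch, color channel, height, width]. You then flattened the last three dimensions together and reshaped the result into [batch, height, width, color channel]. A reshape does not move the channel axis to the end; it only reinterprets the same flat data in the new shape, so the last three dimensions get scrambled. That is why you are seeing 9 sub-images, and also why they look gray.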
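A small NumPy sketch makes the effect concrete. It uses a made-up 3x6x6 channel-first array whose values are just their flat index (the sizes are arbitrary assumptions, not the poster's data), reproduces the flatten-then-reshape step, and checks which original channel each displayed value comes from.

```python
import numpy as np

# Made-up stand-in for one image: 3 channels, 6x6 pixels, stored channel-first
# (C, H, W). Each value equals its flat index, so we can trace where it came from.
C, H, W = 3, 6, 6
chw = np.arange(C * H * W).reshape(C, H, W)

# What the question's code effectively does: flatten, then reshape to (H, W, C).
# reshape only reinterprets the flat buffer; it never moves the channel axis.
wrong = chw.reshape(-1).reshape(H, W, C)

# Correct conversion: move the channel axis to the end.
right = np.transpose(chw, (1, 2, 0))


def channel_of(values):
    """Return the original channel index each flat-index value came from."""
    return values // (H * W)


# In `wrong`, the "R, G, B" of one pixel are three neighbouring values of ONE
# channel, so the three colour components are similar and the pixel looks gray.
print(channel_of(wrong[0, 0]))      # [0 0 0]

# Top third of the rows comes from channel 0, middle third from channel 1,
# bottom third from channel 2 ...
print(channel_of(wrong[:, 0, 0]))   # [0 0 1 1 2 2]

# ... and each displayed row packs 3 consecutive original rows side by side,
# so every channel band shows 3 shrunken copies: 3 bands x 3 copies = 9 tiles.

# After transpose, one pixel holds one value from each channel, as intended.
print(channel_of(right[0, 0]))      # [0 1 2]
```

When the height and width are not multiples of 3 the tile boundaries are not exact, but the 3-by-3 pattern is the same.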

I don't see why you flatten the last three dimensions together at the imgs=... line and in the first line of show_images. If you remove both, a single NumPy call inside the loop, plt.imshow(np.transpose(img, (1, 2, 0))), will work.
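A minimal sketch of that fix, assuming the batch has already been converted to a NumPy array of shape (N, C, H, W). The body of show_images here is hypothetical (the original helper isn't shown), and to_displayable is an illustrative name; the only essential part is the per-image np.transpose(img, (1, 2, 0)).

```python
import numpy as np


def to_displayable(batch):
    """Convert a (N, C, H, W) NumPy batch into a list of (H, W, C) images.

    For a PyTorch tensor, call .cpu().numpy() on it first (and .detach()
    if it requires grad).
    """
    # Move the channel axis to the end for each image; no reshape involved.
    return [np.transpose(img, (1, 2, 0)) for img in batch]


def show_images(batch, ncols=4):
    """Hypothetical helper: plot every image of a (N, C, H, W) batch in a grid."""
    # Imported here so to_displayable can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    images = to_displayable(batch)
    nrows = (len(images) + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                             figsize=(2 * ncols, 2 * nrows))
    for ax in axes.flat:
        ax.axis("off")  # hide ticks, including on unused grid cells
    for ax, img in zip(axes.flat, images):
        # imshow expects floats in [0, 1] or integers in [0, 255] for RGB data.
        ax.imshow(img)
    plt.show()
```

In the notebook this would be called as something like show_images(images.numpy()) on one batch from the data loader. If the dataset applies a normalization transform, the values may fall outside [0, 1], and imshow clips them, so undoing the normalization first gives truer colors.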


Thank you! I removed imgs=… from show_images and stopped reshaping the tensor, and now I get the right output. Thank you again~~