I am trying to view the images loaded by my data loader. I used PIL in a custom loader to load the images. However, when I use the following code to display them, I get this error:
images, labels = next(iter(test_loader))
plt.imshow( images )
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-35-d5057d4c4d54> in <module>()
1 images, labels = next(iter(test_loader))
----> 2 plt.imshow( images )
/usr/local/lib/python3.7/dist-packages/matplotlib/image.py in set_data(self, A)
697 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
698 raise TypeError("Invalid shape {} for image data"
--> 699 .format(self._A.shape))
700
701 if self._A.ndim == 3:
TypeError: Invalid shape (4, 3, 224, 224) for image data
What’s the reason behind this error?
matplotlib.pyplot.imshow expects a NumPy array in a valid image shape (e.g. [height, width, channels] or [height, width]), while you are passing it a tensor in the shape [batch_size, channels, height, width].
You would thus have to index a single image (via image = images[0]) and permute the axes to match the expected shape:
# select the first image of the batch and move channels last: [C, H, W] -> [H, W, C]
image = images[0].permute(1, 2, 0).numpy()
plt.imshow(image)
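If you want every image in the batch in the channel-last layout at once, you can permute the whole batch instead of one image. The sketch below is a minimal, hedged illustration that uses a random NumPy array as a stand-in for the `images` tensor (an assumption, so it runs without a data loader); with a real tensor, `images.permute(0, 2, 3, 1).numpy()` produces the same layout.

```python
import numpy as np

# Stand-in for `images` from next(iter(test_loader)): a NumPy array with the
# same [batch_size, channels, height, width] layout as the tensor in the traceback.
images = np.random.rand(4, 3, 224, 224).astype(np.float32)

# Move the channel axis last for every image at once: [B, C, H, W] -> [B, H, W, C].
images_hwc = np.transpose(images, (0, 2, 3, 1))

# A single image of shape [H, W, C], the layout plt.imshow expects for RGB data.
first_image = images_hwc[0]
print(first_image.shape)  # (224, 224, 3)
```

Each `images_hwc[i]` can then be passed to plt.imshow on its own; passing the 4-D array directly would still raise the same "Invalid shape" error.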
I am trying to display the images like this:
figure = plt.figure()
num_of_images = 60
for index in range(1, num_of_images + 1):
plt.subplot(6, 10, index)
plt.axis('off')
plt.imshow(images[index].numpy().transpose(1,2,0), cmap='gray_r')
I get this error:
TypeError Traceback (most recent call last)
<ipython-input-35-f492e5cf63ef> in <module>()
4 plt.subplot(6, 10, index)
5 plt.axis('off')
----> 6 plt.imshow(images[index].numpy().transpose(1,2,0), cmap='gray_r')
/usr/local/lib/python3.7/dist-packages/matplotlib/image.py in set_data(self, A)
697 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
698 raise TypeError("Invalid shape {} for image data"
--> 699 .format(self._A.shape))
700
701 if self._A.ndim == 3:
TypeError: Invalid shape (28, 28, 1) for image data
The shape (28, 28, 1) should be valid for plt.imshow, so why am I getting this error?
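Your matplotlib version probably uses stricter checks: the check visible in your traceback accepts 3-D arrays only when the last dimension is 3 or 4, so it rejects the additional single-channel dimension. You could remove this dimension if needed.
The same kind of array works in my setup using matplotlib==3.4.2: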
plt.imshow(np.random.randint(0, 255, (24, 24, 1)).astype(np.uint8))
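A minimal sketch of both fixes, assuming NumPy arrays in the [channels, height, width] layout (the helper names `prepare_for_imshow` and `grid_panels` are hypothetical, not part of matplotlib or PyTorch). It drops a trailing single-channel axis, so (28, 28, 1) becomes (28, 28), which the version in the question accepts. It also pairs the 1-based subplot index with a 0-based batch index: the loop in the question reads images[index] for index 1..60, which skips images[0] and needs a batch of at least 61 images.

```python
import numpy as np


def prepare_for_imshow(image_chw):
    """Hypothetical helper: convert one [C, H, W] image to a shape imshow accepts."""
    image_hwc = np.transpose(image_chw, (1, 2, 0))  # [C, H, W] -> [H, W, C]
    if image_hwc.shape[-1] == 1:
        # Grayscale: drop the trailing channel axis, (H, W, 1) -> (H, W).
        image_hwc = np.squeeze(image_hwc, axis=-1)
    return image_hwc


def grid_panels(images, num_of_images):
    """Hypothetical helper: pair 1-based subplot indices with 0-based batch indices."""
    count = min(num_of_images, len(images))  # never index past the end of the batch
    return [(index, prepare_for_imshow(images[index - 1]))
            for index in range(1, count + 1)]


# Stand-in for a batch of four 1x28x28 grayscale images.
batch = np.random.rand(4, 1, 28, 28)
for index, image in grid_panels(batch, 60):
    print(index, image.shape)  # 1 (28, 28) ... 4 (28, 28)
```

In the plotting loop, each pair maps to plt.subplot(6, 10, index) followed by plt.imshow(image, cmap='gray_r'); with a torch tensor, images[index - 1].squeeze(0) removes the channel dimension directly.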