Dataloader not showing images

I am trying to look at the images loaded by the data loader. I used PIL in a custom loader to load the images. However, when I use the following code to display them, I get this error:

images, labels = next(iter(test_loader))
plt.imshow( images )
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-35-d5057d4c4d54> in <module>()
      1 images, labels = next(iter(test_loader))
----> 2 plt.imshow( images )

/usr/local/lib/python3.7/dist-packages/matplotlib/image.py in set_data(self, A)
    697                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    698             raise TypeError("Invalid shape {} for image data"
--> 699                             .format(self._A.shape))
    700 
    701         if self._A.ndim == 3:

TypeError: Invalid shape (4, 3, 224, 224) for image data

What’s the reason behind this error?

matplotlib.pyplot.imshow expects a NumPy array in a valid image shape (e.g. [height, width, channels] or [height, width]), while you are passing a tensor of shape [batch_size, channels, height, width] to it.
You would thus have to index each image (via image = images[0]) and permute the axes to match the expected shape:

image = images[0].permute(1, 2, 0).numpy()
plt.imshow(image)
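A minimal sketch of the same conversion in plain NumPy, assuming the batch has already been turned into an array via .numpy() (the helper name to_imshow_array is hypothetical). It also drops a single channel dimension, since the traceback shows that a 3-D array is only accepted when its last dimension is 3 or 4.

```python
import numpy as np


def to_imshow_array(batch: np.ndarray, index: int = 0) -> np.ndarray:
    """Pick one image from an NCHW batch and return it in a layout imshow accepts.

    Hypothetical helper: assumes `batch` has shape [batch_size, channels, height, width].
    """
    if batch.ndim != 4:
        raise ValueError(f"expected a 4-D NCHW batch, got shape {batch.shape}")
    image = batch[index]                     # [channels, height, width]
    image = np.transpose(image, (1, 2, 0))   # [height, width, channels]
    if image.shape[-1] == 1:
        image = image[..., 0]                # grayscale: drop channel dim -> [height, width]
    return image


# Random batch shaped like the one in the question: (4, 3, 224, 224)
batch = np.random.rand(4, 3, 224, 224).astype(np.float32)
print(to_imshow_array(batch).shape)  # (224, 224, 3)
```

With a PyTorch tensor, images[0].permute(1, 2, 0).numpy() performs the same axis reordering before the array is passed to plt.imshow.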

I am trying to display the images like this:

figure = plt.figure()
num_of_images = 60
for index in range(1, num_of_images + 1):
    plt.subplot(6, 10, index)
    plt.axis('off')
    plt.imshow(images[index].numpy().transpose(1,2,0), cmap='gray_r')

but I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-35-f492e5cf63ef> in <module>()
      4     plt.subplot(6, 10, index)
      5     plt.axis('off')
----> 6     plt.imshow(images[index].numpy().transpose(1,2,0), cmap='gray_r')

/usr/local/lib/python3.7/dist-packages/matplotlib/image.py in set_data(self, A)
    697                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    698             raise TypeError("Invalid shape {} for image data"
--> 699                             .format(self._A.shape))
    700 
    701         if self._A.ndim == 3:

TypeError: Invalid shape (28, 28, 1) for image data

The shape above, (28, 28, 1), should be valid for plt.imshow, so why am I getting this error?

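Your matplotlib version might apply stricter checks and reject the additional channel dimension: the check shown in the traceback only accepts a 3-D array if its last dimension is 3 or 4 (RGB or RGBA). If needed, you could remove this dimension (e.g. via images[index].squeeze()) so that each image has the shape [28, 28].
For reference, an array with a single trailing channel works in my setup using matplotlib==3.4.2: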

plt.imshow(np.random.randint(0, 255, (24, 24, 1)).astype(np.uint8))
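Applied to the plotting loop from the question, a sketch could look like the following. It assumes images is a NumPy array of single-channel images with shape [batch_size, 1, 28, 28] (e.g. from images.numpy()); the function name show_grid is hypothetical. Besides squeezing the channel, it indexes from 0, because range(1, num_of_images + 1) skips the first image and reads images[60], and it caps the number of plots at the batch size.

```python
import matplotlib.pyplot as plt
import numpy as np


def show_grid(images: np.ndarray, rows: int = 6, cols: int = 10):
    """Plot up to rows*cols single-channel images from an NCHW batch (hypothetical helper)."""
    num_of_images = min(rows * cols, len(images))  # never index past the batch
    figure = plt.figure()
    for i in range(num_of_images):
        plt.subplot(rows, cols, i + 1)  # subplot positions are 1-based
        plt.axis('off')
        # images[i] has shape [1, 28, 28]; squeeze() drops the size-1 channel -> [28, 28]
        plt.imshow(images[i].squeeze(), cmap='gray_r')
    return figure


# Random stand-in for an MNIST-like batch of 60 images
batch = np.random.rand(60, 1, 28, 28)
fig = show_grid(batch)
```

In a notebook the returned figure is rendered automatically; in a script, call plt.show() afterwards. Note that squeeze() only yields a 2-D image for single-channel data; RGB batches still need the permute/transpose shown earlier.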