Matplotlib imshow error

Hello, I'm completely new to PyTorch and I'm trying to display an MNIST image in a Jupyter notebook as below:

import torchvision
import matplotlib.pyplot as plt

train_dataset = torchvision.datasets.MNIST(root='../../data/', train=True, download=True)

img=train_dataset[0][0]
plt.imshow(img)

But the code fails with the following error, and I can't figure out what is wrong:

<matplotlib.image.AxesImage at 0x1cf2856eda0>
Error in callback <function install_repl_displayhook.<locals>.post_execute at 0x000001CF2580A378> (for post_execute):

AttributeError Traceback (most recent call last)
~\Anaconda3\lib\site-packages\matplotlib\pyplot.py in post_execute()
148 def post_execute():
149 if matplotlib.is_interactive():
--> 150 draw_all()
151
152 # IPython >= 2

~\Anaconda3\lib\site-packages\matplotlib\_pylab_helpers.py in draw_all(cls, force)
148 for f_mgr in cls.get_all_fig_managers():
149 if force or f_mgr.canvas.figure.stale:
--> 150 f_mgr.canvas.draw_idle()
151
152 atexit.register(Gcf.destroy_all)

~\Anaconda3\lib\site-packages\matplotlib\backend_bases.py in draw_idle(self, *args, **kwargs)
2059 if not self._is_idle_drawing:
2060 with self._idle_draw_cntx():
-> 2061 self.draw(*args, **kwargs)
2062
2063 def draw_cursor(self, event):

~\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py in draw(self)
428 # if toolbar:
429 # toolbar.set_cursor(cursors.WAIT)
--> 430 self.figure.draw(self.renderer)
431 finally:
432 # if toolbar:

~\Anaconda3\lib\site-packages\matplotlib\artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:

~\Anaconda3\lib\site-packages\matplotlib\figure.py in draw(self, renderer)
1297
1298 mimage._draw_list_compositing_images(
-> 1299 renderer, self, artists, self.suppressComposite)
1300
1301 renderer.close_group('figure')

~\Anaconda3\lib\site-packages\matplotlib\image.py in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
136 if not_composite or not has_images:
137 for a in artists:
--> 138 a.draw(renderer)
139 else:
140 # Composite any adjacent images together

~\Anaconda3\lib\site-packages\matplotlib\artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:

~\Anaconda3\lib\site-packages\matplotlib\axes\_base.py in draw(self, renderer, inframe)
2435 renderer.stop_rasterizing()
2436
-> 2437 mimage._draw_list_compositing_images(renderer, self, artists)
2438
2439 renderer.close_group('axes')

~\Anaconda3\lib\site-packages\matplotlib\image.py in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
136 if not_composite or not has_images:
137 for a in artists:
--> 138 a.draw(renderer)
139 else:
140 # Composite any adjacent images together

~\Anaconda3\lib\site-packages\matplotlib\artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:

~\Anaconda3\lib\site-packages\matplotlib\image.py in draw(self, renderer, *args, **kwargs)
564 else:
565 im, l, b, trans = self.make_image(
--> 566 renderer, renderer.get_image_magnification())
567 if im is not None:
568 renderer.draw_image(gc, l, b, im)

~\Anaconda3\lib\site-packages\matplotlib\image.py in make_image(self, renderer, magnification, unsampled)
791 return self._make_image(
792 self._A, bbox, transformed_bbox, self.axes.bbox, magnification,
--> 793 unsampled=unsampled)
794
795 def _check_unsampled_image(self, renderer):

~\Anaconda3\lib\site-packages\matplotlib\image.py in _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)
428
429 mask = np.empty(A.shape, dtype=np.float32)
--> 430 if A.mask.shape == A.shape:
431 # this is the case of a nontrivial mask
432 mask[:] = np.where(A.mask, np.float32(np.nan),

AttributeError: 'numpy.ndarray' object has no attribute 'mask'

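Because the dataset was created without a `transform`, `train_dataset[0][0]` is a PIL image rather than a tensor or array. The simplest way to view it is PIL's own `show()` method. Note that this opens the image in an external viewer rather than inline in the notebook: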

train_dataset[0][0].show()

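If you want to use `plt.imshow` instead, convert the PIL image to a NumPy array first. Your version of matplotlib fails when it is handed the PIL image directly: that is the `'numpy.ndarray' object has no attribute 'mask'` error at the bottom of the traceback. Something like: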

import numpy as np

plt.imshow(np.array(train_dataset[0][0]).squeeze(), cmap='gray')
plt.show()