Your image tensor has shape [batch, color channel, w, h]. You then flattened the last three dimensions together and reshaped the result into [batch, w, h, color channel]. But a reshape only reinterprets the flat buffer; it does not permute axes, so the last three dimensions end up scrambled. That is why you are seeing 9 subimages.
I don't see why you flatten the last three dimensions at the `imgs = ...` line and again at the first line of `show_images`. If you remove both, a single NumPy call in the loop, `plt.imshow(np.transpose(img, (1, 2, 0)))`, will work.
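A minimal sketch of why `reshape` cannot substitute for `transpose` here (the shapes are illustrative, not taken from your code):

```python
import numpy as np

# A toy "image" in (channel, height, width) layout: 3 channels, 4x5 pixels.
img = np.arange(3 * 4 * 5).reshape(3, 4, 5)

# transpose permutes the axes: the channel axis moves last,
# and each pixel keeps its channel values together.
hwc = np.transpose(img, (1, 2, 0))
print(hwc.shape)  # (4, 5, 3)

# reshape just reinterprets the same flat buffer in a new shape;
# it does NOT move data, so the resulting "image" is scrambled.
scrambled = img.reshape(4, 5, 3)
print(np.array_equal(hwc, scrambled))  # False
```

The same distinction applies to PyTorch tensors (`tensor.permute(1, 2, 0)` vs. `tensor.reshape(...)`), which is why flattening and reshaping produced garbled tiles instead of a reordered image.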