Invalid image shape when using plt.imshow

I am following the PyTorch tutorial on DCGAN and am trying to view the model's output images individually instead of as a grid. This is the code I used:

# generating new images

# sets a random seed
manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

fake = gen(torch.randn(1, n_z, 1, 1, device='cpu')).detach()
plt.figure(figsize=(15, 15))
plt.imshow(fake)

I am able to view the output as a grid using the following call, but that doesn't serve my purpose, since I would like to save each image later in a loop.

plt.imshow(np.transpose(vutils.make_grid(fake, padding=2, normalize=True), (1,2,0)))

I was also looking at the notebook at https://github.com/ayan-aji-nair/art-generation-gan/blob/main/Notebooks/DCGAN_128.ipynb, which is an adaptation of PyTorch's version.

make_grid will create a grid using all samples in the current batch and transform it into an image format.
Based on your code I guess you are seeing an error claiming that the input shape is invalid.
If so, this error is expected, since matplotlib expects input arrays in an image format of [height, width, channels], where the channels dimension may be omitted for a grayscale image.
Your model output should have the shape [batch_size, channels, height, width], so you would have to index the sample you would like to visualize in the batch dimension and then permute it to the aforementioned channels-last format.

Hi! Thanks for the reply. I permuted the image but am still getting an error.

This is my code:

# generating new images

# sets a random seed
manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

fake = gen(torch.randn(1, n_z, 1, 1, device='cpu')).detach()
plt.figure(figsize=(4, 200))
fake.shape
plt.imshow(fake.permute(1,2,0))

This is the error message:

Random Seed: 5002

RuntimeError Traceback (most recent call last)
in ()
13 plt.figure(figsize=(4, 200))
14 fake.shape
---> 15 plt.imshow(fake.permute(1,2,0))
16
17 #plt.imshow(np.transpose(vutils.make_grid(fake, padding=2, normalize=True), (1,2,0)))

RuntimeError: number of dims don't match in permute

It seems you forgot to index the sample first before permuting.
Something like this should work: plt.imshow(fake[0].permute(1,2,0))
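To make the shapes concrete, here is a minimal sketch. The random tensor is only a stand-in for the generator output, and the 3x64x64 image size is an assumption, so your actual shapes may differ:

```python
import torch

# Stand-in for gen(...).detach(): a batch holding one 3-channel 64x64 image.
# The channel count and spatial size are assumptions for illustration.
fake = torch.randn(1, 3, 64, 64)
print(fake.shape)  # 4 dimensions: [batch_size, channels, height, width]

# Permuting the whole batch with only three indices raises a RuntimeError,
# because the tensor has four dimensions.
try:
    fake.permute(1, 2, 0)
except RuntimeError as err:
    print("permute on the batch failed:", err)

# Index the sample first (drops the batch dimension), then move the
# channels to the last position for matplotlib.
img = fake[0].permute(1, 2, 0)
print(img.shape)  # [height, width, channels]
```

With a batch size larger than 1, fake[i] would select the i-th sample in the same way.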

Thanks a lot! I can now see the image, but it is completely black and shows no detail. Also, I can't find a way to remove the X and Y axis numbers from the plot. I want to generate art using PyTorch, and the axes are not going away.

Code:

# sets a random seed
manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

plt.imshow(np.transpose(fake[0], (1, 2, 0)))

Output

You are receiving a warning that your data will be clipped, which could explain the black image.
Check the values in the tensor you would like to visualize and make sure they are either floating point values in [0, 1] or uint8 values in [0, 255].
To remove the axes you could use plt.axis('off'); refer to the matplotlib docs for more information.
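As a rough sketch of how this could look when saving each image in a loop: the helper to_displayable, the output directory, and the file names below are hypothetical, and the random tensor stands in for your generator output. The min-max rescaling is one possible way to bring the values into [0, 1] (similar in spirit to what normalize=True does in make_grid); if your generator ends in a tanh, (x + 1) / 2 would be another option.

```python
import os
import tempfile

import matplotlib.pyplot as plt
import torch


def to_displayable(img_chw):
    """Hypothetical helper: [C, H, W] tensor -> [H, W, C] float array in [0, 1]."""
    img = img_chw.detach().cpu().float()
    lo, hi = img.min(), img.max()
    # Min-max rescale; clamp_min guards against division by zero
    # when the image is constant.
    img = (img - lo) / (hi - lo).clamp_min(1e-8)
    return img.permute(1, 2, 0).numpy()


# Stand-in for gen(noise).detach(): 4 images with values in (-1, 1).
# Batch size and image size are assumptions for illustration.
fake = torch.tanh(torch.randn(4, 3, 64, 64))

out_dir = tempfile.mkdtemp()  # hypothetical output location
saved = []
for i, sample in enumerate(fake):
    fig = plt.figure(figsize=(4, 4))
    plt.imshow(to_displayable(sample))
    plt.axis("off")  # hide the axis ticks and labels
    path = os.path.join(out_dir, f"fake_{i}.png")
    plt.savefig(path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # free the figure before the next iteration
    saved.append(path)

print(saved)
```

Iterating over the batch tensor yields one [channels, height, width] sample at a time, so no explicit indexing is needed inside the loop.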