What is the correct way of using torchvision.utils.make_grid()?

Hello all, I have been trying to use this method, but it fails every time. Basically, I have two images that I store as a pair in a list (using img_lst.append((img1, img2))). One image is the input image and the other is its reconstruction.
They are both grayscale images.
This is what I tried:

import matplotlib.pyplot as plt
import torchvision

def visualize(imgs_list, rows=2, cols=10):
    fig = plt.figure(figsize=(10, 5))
    print(f'number of samples: {len(imgs_list)}')
    for i in range(len(imgs_list)):
        img, recons = imgs_list[i]
        img = img.cpu().permute(1, 2, 0)
        recons = recons.cpu().detach().permute(1, 2, 0)
        f = [img, recons]
        ax = fig.add_subplot(rows, cols, i + 1, xticks=[], yticks=[])
        ax.imshow(torchvision.utils.make_grid(f), cmap='Greys_r')

But this fails with the error:

RuntimeError: The expanded size of the tensor (3) must match the existing size (28) at non-singleton dimension 0. Target sizes: [3, 28, 1]. Tensor sizes: [28, 28, 1]

If I remove the permute calls, it complains that the dimensions are invalid:

TypeError: Invalid dimensions for image data
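The second message is raised by matplotlib rather than by make_grid: imshow only accepts arrays shaped (M, N), (M, N, 3) or (M, N, 4), while make_grid returns a channel-first tensor. A minimal sketch, assuming a hypothetical helper imshow_accepts and zero-filled arrays in place of real images, checks that rule for the (3, 32, 62) shape that make_grid produces for two [1, 28, 28] inputs with default settings. Newer matplotlib versions word the error as "Invalid shape ... for image data", but it is still a TypeError.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def imshow_accepts(shape):
    """Hypothetical helper: return True if Axes.imshow accepts an array of this shape."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(np.zeros(shape, dtype=np.float32))
        return True
    except TypeError:  # matplotlib rejects unsupported image shapes with TypeError
        return False
    finally:
        plt.close(fig)

print(imshow_accepts((32, 62, 3)))  # channels last (HWC): accepted
print(imshow_accepts((3, 32, 62)))  # channels first (CHW), as make_grid returns: rejected
print(imshow_accepts((32, 62)))     # plain 2-D grayscale: accepted
```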

I’m completely lost here! How is this supposed to work?
Can anyone kindly help me understand this?
Thanks a lot in advance!

Hi @Shisho_Sama

The images have to be in the CHW layout (channel first), with img.shape == recons.shape.
Can you check the shapes by printing them?
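To make the layout requirement concrete, here is a small sketch with random tensors standing in for img and recons (the data is made up; only the shapes matter). It shows what permute(1, 2, 0) in the original loop does to a [1, 28, 28] tensor, and why both images need the same shape: make_grid stacks a list of images into a single batch before tiling it.

```python
import torch

img = torch.rand(1, 28, 28)     # CHW: 1 channel, 28 rows, 28 columns
recons = torch.rand(1, 28, 28)  # random stand-in for the reconstruction

hwc = img.permute(1, 2, 0)      # HWC, which is what the original loop passed to make_grid
print(hwc.shape)                # torch.Size([28, 28, 1])
back = hwc.permute(2, 0, 1)     # undo the permute to get CHW again
print(back.shape)               # torch.Size([1, 28, 28])

# make_grid stacks the list into one batch, so the shapes have to agree
assert img.shape == recons.shape
batch = torch.stack([img, recons])
print(batch.shape)              # torch.Size([2, 1, 28, 28])
```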

They are in CHW. Here is the output of
print(img.shape,recons.shape)
torch.Size([1, 28, 28]) torch.Size([1, 28, 28])

Are you sure that the error comes from make_grid? It might come from imshow as well. Can you try to run it without ax.imshow?

You were right! The error was caused by matplotlib’s imshow().
I had to do something like this:

        x = torchvision.utils.make_grid([img,recons])
        ax.imshow(x.numpy().transpose(1,2,0))

Oh right, now I remember: matplotlib’s imshow expects HWC images!
I’m glad it worked :slightly_smiling_face: