How does np.transpose() permute the axes?

I am trying to visualize the images in one batch of the data loader. I got this code snippet from the PyTorch official site.

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1,2,0))
    print(inp.shape)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    # plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data
image, label = next(iter(train_loader))
print(image.shape)
#print(label)
# Make a grid from batch
out = torchvision.utils.make_grid(image)
imshow(out, title=[label[x] for x in label])

However, I am not able to understand what transpose is doing in inp.numpy().transpose((1, 2, 0)). I know that it permutes the axes. However, the corresponding tensor shape is torch.Size([4, 3, 224, 224]); in that case (1, 2, 0) corresponds to (3, 224, 4). How, then, is imshow correctly displaying the images?

Apart from that, the shape of the tensor image is (3, 224, 224), but when it is transformed to an ndarray, why does the shape change to (228, 906, 3)? Shouldn't it become (224, 224, 3)?

Hi Aleemsidra!

You should get an error here. You need to pass four axes to numpy’s
transpose() to transpose a 4-d tensor.

I have no idea where your (228, 906, 3) is coming from. I get
(224, 224, 3), which is what you expect and want. (matplotlib’s
imshow() wants the three color channels as the last dimension,
rather than the first. That’s the purpose of the transpose().)

Here’s an illustration of these points:

>>> import torch
>>> torch.__version__
'1.7.1'
>>> inp = torch.randn(4, 3, 224, 224)
>>> inp.shape
torch.Size([4, 3, 224, 224])
>>> inp.numpy().transpose((1, 2, 0))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: axes don't match array
>>> inp[0].shape
torch.Size([3, 224, 224])
>>> inp[0].numpy().transpose((1, 2, 0)).shape
(224, 224, 3)
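The general rule behind the session above: transpose(axes) produces an array whose i-th dimension is the input's axes[i]-th dimension, i.e. new_shape[i] = old_shape[axes[i]]. A tiny plain-Python sketch of just that shape rule (the helper name is mine, for illustration only):

```python
def permuted_shape(shape, axes):
    """Shape after a transpose: new_shape[i] = shape[axes[i]]."""
    return tuple(shape[a] for a in axes)

# (1, 2, 0) on a (C, H, W) image moves the channels to the last axis:
print(permuted_shape((3, 224, 224), (1, 2, 0)))  # (224, 224, 3)

# A 4-d tensor needs a 4-element permutation, e.g. (0, 2, 3, 1):
print(permuted_shape((4, 3, 224, 224), (0, 2, 3, 1)))  # (4, 224, 224, 3)
```

This is why passing a 3-element permutation to a 4-d array raises the ValueError above: there is no axes[i] for the fourth dimension.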

Best.

K. Frank

I realized that I am using the make_grid function, which changes the dimension from 4D to 3D. It also changes the shape from (224, 224, 3) to (228, 906, 3), and it is still confusing to me how it picks these dimensions.

The “grid” of images is created by flattening the batch of image tensors (4D) into a single image tensor (3D).
The output shape comes from placing the images next to each other, using the specified nrow and padding arguments as described in the docs.
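For the concrete numbers: with torchvision's defaults (nrow=8, padding=2), your 4 images of shape (3, 224, 224) fit in a single row, and each image gets padding pixels between it and its neighbors plus a padded border. A small sketch of that arithmetic (plain Python mirroring the documented layout, not torchvision's actual code):

```python
import math

def grid_shape(batch, channels, height, width, nrow=8, padding=2):
    """Output shape of a make_grid-style layout: images are placed
    in rows of at most `nrow`, with `padding` pixels between images
    and around the border."""
    xmaps = min(nrow, batch)          # images per row
    ymaps = math.ceil(batch / xmaps)  # number of rows
    grid_h = (height + padding) * ymaps + padding
    grid_w = (width + padding) * xmaps + padding
    return (channels, grid_h, grid_w)

# A batch of 4 images, each 3 x 224 x 224:
print(grid_shape(4, 3, 224, 224))  # (3, 228, 906)
```

After imshow's transpose((1, 2, 0)), that (3, 228, 906) grid becomes exactly the (228, 906, 3) array you observed.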