Trying to understand torch.transpose

I am going through the steps of the tutorial on training a classifier (http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).
In this tutorial we have plt.imshow(np.transpose(npimg, (1, 2, 0))), which shows npimg after transposing it.
Can you explain exactly how it transposes it?

Based on torch.transpose (http://pytorch.org/docs/master/torch.html), if (1, 2, 0) corresponds to (dim0, dim1, dim2), what exactly does it do?
Sorry, but I cannot understand what is going on here…
Thanks

Hi, in PyTorch the order of image dimensions is channel*height*width, but matplotlib expects height*width*channel. That’s why the transpose is needed.
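
As a quick check of the shapes involved, here is a minimal sketch. It uses a random NumPy array as a stand-in for the tutorial's npimg (the contents are made up for illustration, and the name hwc is hypothetical):

```python
import numpy as np

# Stand-in for the tutorial's npimg: a channel-first image of shape (3, 32, 32).
npimg = np.random.rand(3, 32, 32)

# Move the channel axis to the end, which is the layout imshow expects for RGB data.
hwc = np.transpose(npimg, (1, 2, 0))

print(npimg.shape)  # (3, 32, 32)
print(hwc.shape)    # (32, 32, 3)
```

Only the order of the axes changes; the pixel values themselves are untouched.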


NumPy and Matplotlib lay out the 3 dimensions of an image as x, y, z, whereas PyTorch lays them out as z, x, y (channel first). With np.transpose(npimg, (1, 2, 0)) we are telling NumPy how to reorder the axes, i.e.
for a torch image array of shape (3, 32, 32), where 3 is z (channel), 32 is x and 32 is y, np.transpose(npimg, (1, 2, 0)) builds a new array whose axes are the original axes 1, 2, 0, in that order: x and y come first and z (channel) moves to the end, giving shape (32, 32, 3).
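
The index mapping can be checked directly. This is a small sketch that uses a made-up array of shape (3, 4, 5) instead of the tutorial's image, so that each axis has a different length and is easy to tell apart (the names chw and hwc are hypothetical):

```python
import numpy as np

# Channel-first array: axis 0 has length 3, axis 1 length 4, axis 2 length 5.
chw = np.arange(3 * 4 * 5).reshape(3, 4, 5)

# New axis 0 <- old axis 1, new axis 1 <- old axis 2, new axis 2 <- old axis 0.
hwc = np.transpose(chw, (1, 2, 0))

print(hwc.shape)  # (4, 5, 3)

# The element at [i, j, k] of the result is the element at [k, i, j] of the input.
print(hwc[2, 3, 1] == chw[1, 2, 3])  # True
```

Note that torch.transpose swaps exactly two dimensions, so it does not accept a tuple like (1, 2, 0); on a tensor, the corresponding reordering would be written with permute(1, 2, 0).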
