Problem with torch.view

Sorry, I’m new and this is my first question.

I’m trying to plot a PyTorch image tensor whose shape is (channels, height, width), for example (3, 256, 256). To plot it I’m using matplotlib's pyplot.imshow, but this function expects an array whose shape is (height, width, channels), so (256, 256, 3).

To do this transformation I have always used torch.view. However, I found that this method repeats the image, as you can see in the left image.

On the other hand, I get the correct result when I convert the tensor to NumPy and then use img = np.einsum("xyz -> yzx", img), as you can see in the right image.
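For reference, the subscript string "xyz -> yzx" labels the input axes (channel, height, width) as x, y, z and outputs them in the order y, z, x, so it simply moves the channel axis to the end. The quick check below uses a random array as a stand-in for out[2] and compares it with np.transpose and np.moveaxis:

```python
import numpy as np

# Stand-in for a (channels, height, width) image; the values are arbitrary.
chw = np.random.rand(3, 4, 5)

# Three ways to move the channel axis to the end: (C, H, W) -> (H, W, C).
via_einsum = np.einsum("xyz->yzx", chw)
via_transpose = np.transpose(chw, (1, 2, 0))
via_moveaxis = np.moveaxis(chw, 0, -1)

print(via_einsum.shape)                           # (4, 5, 3)
print(np.array_equal(via_einsum, via_transpose))  # True
print(np.array_equal(via_einsum, via_moveaxis))   # True
```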

Could you help me understand the difference between the two methods? Also, can torch.view be used to transpose? Thanks.

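    ### This code produces the second (right) image ###

    import matplotlib.pyplot as plt
    import numpy as np

    # Shape: out[2] = (3, 256, 256)
    img = out[2].numpy()                # tensor -> ndarray, still (C, H, W)
    img = np.einsum("xyz -> yzx", img)  # move channels last: (H, W, C)
    plt.imshow(img)
    plt.show()
    plt.clf()
    plt.close()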


    ### This code produces the first (left) image ###

    # Shape: out[2] = (3, 256, 256)
    img = out[2].view(256, 256, 3)  # reinterpret the same data as (256, 256, 3)
    plt.imshow(img)
    plt.show()
    plt.clf()
    plt.close()
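Tracing the data makes the left image less mysterious. The sketch below assumes a contiguous (3, 256, 256) tensor like out[2]; it fills a tensor with each element's own flat index, applies the same view(256, 256, 3), and maps values back to their origin. The helper source is hypothetical, written only for this illustration.

```python
import torch

C, H, W = 3, 256, 256
# Fill the "image" with each element's own flat index, so every value is traceable.
idx = torch.arange(C * H * W).view(C, H, W)
viewed = idx.view(H, W, C)  # same reinterpretation as out[2].view(256, 256, 3)

def source(flat: int):
    """Hypothetical helper: map a flat memory index to (channel, row, column)."""
    return flat // (H * W), (flat % (H * W)) // W, flat % W

# Along the first viewed row, three original rows appear one after another,
# each keeping only every third column: three shrunken copies side by side.
print([source(int(viewed[0, c, 0])) for c in (0, 86, 171)])
# [(0, 0, 0), (0, 1, 2), (0, 2, 1)]

# Going down, the source channel switches roughly every H / 3 viewed rows.
print([source(int(viewed[r, 0, 0]))[0] for r in (0, 85, 86, 170, 171)])
# [0, 0, 1, 1, 2]
```

In other words, each displayed pixel takes its three "colour" values from three neighbouring pixels of a single channel, each viewed row spans three original rows, and each third of the viewed image comes from one channel. That would be consistent with the grid of repeated, shrunken copies in the left image.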

My understanding is that what you want is to permute the axes, like this:

    import torch

    x = torch.randn((3, 256, 256))
    x = x.permute(1, 2, 0)  # shape (256, 256, 3)
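Building on that permute call, a small plotting helper could wrap the permute together with the conversion to NumPy. This is only a sketch: the name chw_to_hwc is hypothetical, and the .detach().cpu() calls assume the tensor might require grad or live on a GPU (they are harmless otherwise). A random tensor stands in for out[2].

```python
import numpy as np
import torch

def chw_to_hwc(img: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: (C, H, W) tensor -> (H, W, C) ndarray for plt.imshow."""
    # detach()/cpu() let .numpy() work for tensors that require grad or are on a GPU.
    hwc = img.detach().cpu().permute(1, 2, 0)
    # permute only changes strides, so hwc is not contiguous: hwc.view(-1) would
    # raise an error (use .reshape() or .contiguous() for that), but .numpy() is fine.
    return hwc.numpy()

x = torch.rand(3, 256, 256)  # stand-in for out[2]
img = chw_to_hwc(x)
print(img.shape)  # (256, 256, 3)

# Each pixel now holds the three channel values of the same spatial location.
print(np.array_equal(img[10, 20], x[:, 10, 20].numpy()))  # True
```

Usage would then be plt.imshow(chw_to_hwc(out[2])) followed by plt.show().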

Using einsum the way you did is equivalent to permuting the axes, but that is not what torch.view does. view does not move any data: it keeps the elements in their existing memory order and only regroups them into the requested shape, whereas permute actually swaps the axes. So view cannot be used as a transpose in general. You might want to check out this post to see a better example of what view is doing.
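To make the difference concrete, here is a tiny (3, 2, 3) "image" filled with 0..17, so every element is identifiable. view regroups values in memory order, so the first "pixel" of the viewed tensor is built from three neighbouring values of channel 0, while permute gathers one value from each channel. The 2-D case at the end addresses the transpose question directly.

```python
import torch

# x[c, h, w] = 6*c + 3*h + w: channel 0 holds 0..5, channel 1 holds 6..11, channel 2 holds 12..17.
x = torch.arange(18).reshape(3, 2, 3)  # (C, H, W)

viewed = x.view(2, 3, 3)        # same memory order, new shape
permuted = x.permute(1, 2, 0)   # axes swapped: (H, W, C)

print(viewed[0, 0].tolist())    # [0, 1, 2]  -> three values from channel 0
print(permuted[0, 0].tolist())  # [0, 6, 12] -> one value per channel

# Same story in 2-D: view reshapes, it does not transpose.
m = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
print(m.view(3, 2).tolist())       # [[0, 1], [2, 3], [4, 5]]
print(m.t().tolist())              # [[0, 3], [1, 4], [2, 5]]
```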

Thanks for your explanation!