Cropped Tensor mixed result

Hi, I’m trying to crop a section of a 4D tensor [batch, channel, height, width] that was originally a numpy image. I’m using numpy-like indexing; here’s the code:

    # img_mod is a pytorch tensor that was a numpy array
    b = img_mod.shape[0]
    c = img_mod.shape[1]
    h = img_mod.shape[2]
    w = img_mod.shape[3]

    # Trying to crop the first quadrant
    img_mod_frag = img_mod[:, :, 0:h//2, 0:w//2]

    # Reshape the components to make the tensors displayable
    # Display img_mod
    plt.imshow(img_mod.reshape(h, w, c))
    plt.title("img_mod")
    plt.show()

    # Display img_mod_frag
    plt.imshow(img_mod_frag.reshape(h//2, w//2, c))
    plt.title("img_mod_frag")
    plt.show()

My intention is to crop the first quadrant of the image, but these are the results:

The resulting image is ‘mixed’, and I don’t know whether this is the expected behaviour of tensors. I would then like to process the fragment in a simple CNN, but even when I crop the image properly (by converting it to numpy and indexing it), my model() function ‘mixes’ the fragment too, just like in the img_mod_frag case.

My question is whether this is the correct way to crop a tensor quadrant. If it’s not, what’s the correct way? Am I doing something wrong, or should I train my model with these ‘mixed’ fragments?

Thanks.

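You should always be careful with the reshape method: it keeps the elements in their existing row-major order and only regroups them into the new shape, so reshaping [C, H, W] to [H, W, C] does not move the channel dimension; it scrambles the pixels instead. To change the order of dimensions, use transpose or permute. Regarding cropping tensors, I suggest you take a look at torchvision.transforms.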
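To see why reshape ‘mixes’ the image while permute does not, here is a minimal sketch on a tiny hand-built [C, H, W] tensor (the values and variable names are made up for illustration). Each channel is offset by 10, so you can tell at a glance which channel a value came from.

```python
import torch

# Hypothetical tiny "image" in [C, H, W] layout: 3 channels, 2x2 pixels.
# Channel k holds 10*k + pixel index: channel 0 = [[0, 1], [2, 3]], channel 1 = [[10, 11], [12, 13]], ...
chw = torch.arange(4).reshape(1, 2, 2) + 10 * torch.arange(3).reshape(3, 1, 1)

# permute moves the axes, so pixel (0, 0) keeps its own three channel values
hwc_permute = chw.permute(1, 2, 0)
# reshape only regroups the row-major sequence 0, 1, 2, 3, 10, 11, ... into [2, 2, 3],
# so "pixel (0, 0)" ends up holding three different pixels of channel 0
hwc_reshape = chw.reshape(2, 2, 3)

print(hwc_permute[0, 0].tolist())  # [0, 10, 20]
print(hwc_reshape[0, 0].tolist())  # [0, 1, 2]

# A numpy image is usually [H, W, C]; permute (not reshape) converts it back to [C, H, W]
np_hwc = hwc_permute.numpy()
back_to_chw = torch.from_numpy(np_hwc).permute(2, 0, 1)
print(torch.equal(back_to_chw, chw))  # True
```

Both results have the same shape, which is why the reshape version raises no error. If a reshape like this were used anywhere between the numpy image and model(), it would scramble the input in the same way.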

Here’s one example using the permute method:

    from PIL import Image
    import matplotlib.pyplot as plt
    import torch
    import torchvision.transforms as transforms

    # Load the image and convert it to a [1, C, H, W] float tensor
    image = Image.open("Lenna.png")
    img_mod = transforms.ToTensor()(image).unsqueeze(0)

    b = img_mod.shape[0]
    c = img_mod.shape[1]
    h = img_mod.shape[2]
    w = img_mod.shape[3]

    # Crop the first (top-left) quadrant; slicing keeps the [B, C, H, W] layout
    img_mod_frag = img_mod[:, :, 0:h//2, 0:w//2]

    # Permute [C, H, W] -> [H, W, C] so matplotlib can display the tensors
    # Display img_mod
    fig = plt.figure()
    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(img_mod[0].permute(1, 2, 0))
    ax.set_title("img_mod")

    # Display img_mod_frag
    ax = fig.add_subplot(1, 2, 2)
    ax.imshow(img_mod_frag[0].permute(1, 2, 0))
    ax.set_title("img_mod_frag")

    plt.savefig("Lenna2.png")

[Image: Lenna2.png]

Thank you so much! I clearly see how it works now.