Split an image into a 2 by 2 grid

So right now I’m trying to split a tensor (of an image) of dimension 3 x 32 x 32 into a 2 x 2 grid, then stack the parts of the grid on top of each other to form a 4 x 3 x 16 x 16 tensor. Along dimension 0, element 0 will be the top-left piece of the image, element 1 the top-right, element 2 the bottom-left, and element 3 the bottom-right.

My code so far looks like this:

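import torch

def grid_image(image_tensor, grid_length=2):
    n_channels, n_rows, n_cols = image_tensor.shape
    row_length = n_rows // grid_length
    col_length = n_cols // grid_length

    patches = []
    # Cut the image into horizontal strips along the row dimension (dim 1).
    rows = torch.split(image_tensor, row_length, 1)
    for row in rows:
        # Cut each strip into patches along the column dimension (dim 2).
        row_patches = torch.split(row, col_length, 2)
        patches += row_patches

    # Stack in row-major order: top-left, top-right, bottom-left, bottom-right.
    return torch.stack(patches)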

I feel like there is a better way to do this, because I am currently splitting the image by rows and then by columns. I’m looking for an approach that does both at the same time, which would presumably make the function faster. I searched around and didn’t find anything. Could I please have some help?

tensor.unfold should work.
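
As a minimal sketch of what unfold does, assuming a tiny single-channel 4 x 4 tensor in place of the real image (the names image, rows, tiles and top_right are only illustrative):

```python
import torch

# A small single-channel 4 x 4 "image" so the values are easy to follow.
image = torch.arange(16).reshape(1, 4, 4)

# unfold(dimension, size, step) slides a window of length `size` along
# `dimension`, advancing by `step`; the window contents are placed in a
# new trailing dimension.
rows = image.unfold(1, 2, 2)    # shape (1, 2, 4, 2)
tiles = rows.unfold(2, 2, 2)    # shape (1, 2, 2, 2, 2)

# tiles[c, i, j] is the patch in grid row i, grid column j of channel c.
top_right = tiles[0, 0, 1]      # rows 0-1, columns 2-3 of the image
print(tiles.shape)
print(top_right)
```

With size equal to step, the windows do not overlap, so the two calls together cut the image into a grid without an explicit loop.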

Thanks for your answer, ptrblck. I managed to get the code below working for a grid of any size.

# unfold(dim, size, step): rows use row_length, columns use col_length.
patches = (image_tensor
           .unfold(1, row_length, row_length)
           .unfold(2, col_length, col_length)
           .reshape(n_channels, grid_length**2, row_length, col_length)
           .permute(1, 0, 2, 3))
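
For completeness, here is one way the snippet could be wrapped into a self-contained function. This is a sketch under some assumptions: the name grid_image_unfold is made up for illustration, the input is a C x H x W tensor, and H and W are both divisible by grid_length (otherwise unfold will not yield exactly grid_length windows per dimension). It permutes before reshaping, which gives the same row-major patch order as the split-based version.

```python
import torch


def grid_image_unfold(image_tensor, grid_length=2):
    """Split a C x H x W tensor into grid_length**2 patches (hypothetical helper).

    Assumes H and W are divisible by grid_length.
    """
    n_channels, n_rows, n_cols = image_tensor.shape
    row_length = n_rows // grid_length
    col_length = n_cols // grid_length

    # Non-overlapping windows: size == step in each dimension.
    # Resulting shape: (C, G, G, row_length, col_length).
    tiles = (image_tensor
             .unfold(1, row_length, row_length)
             .unfold(2, col_length, col_length))

    # Move the two grid dimensions to the front, then merge them so that
    # patches are ordered top-left, top-right, ..., bottom-right.
    return (tiles
            .permute(1, 2, 0, 3, 4)
            .reshape(-1, n_channels, row_length, col_length))


# Usage on a non-square image, to exercise the row/column distinction.
image = torch.arange(3 * 32 * 48).reshape(3, 32, 48)
patches = grid_image_unfold(image, grid_length=2)
print(patches.shape)  # 4 patches of 3 x 16 x 24
```

Note that unfold itself returns a view of the original tensor, while the final reshape has to copy because the permuted view is not contiguous. Both this and the split-based version therefore touch every element once, so any difference is in constant factors rather than asymptotic complexity.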