# Split an image into a 2 by 2 grid

So right now I'm trying to split a tensor (of an image) of dimension 3 x 32 x 32 into a 2 x 2 grid, then stack the grid pieces on top of each other to form a 4 x 3 x 16 x 16 tensor. Along dimension 0, element 0 should be the top-left piece of the image, element 1 the top-right piece, element 2 the bottom-left piece, and element 3 the bottom-right piece.
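
One way to pin down the intended ordering is a plain slicing version that walks the grid row by row. This is a sketch for reference, not code from the thread; the function name `grid_by_slicing` is hypothetical, and it assumes the image height and width divide evenly by `grid_length`.

```python
import torch

def grid_by_slicing(image: torch.Tensor, grid_length: int = 2) -> torch.Tensor:
    """Hypothetical reference version: slice each patch explicitly in
    row-major order (top-left, top-right, bottom-left, bottom-right for 2 x 2).
    Assumes H and W are divisible by grid_length."""
    n_channels, n_rows, n_cols = image.shape
    ph, pw = n_rows // grid_length, n_cols // grid_length
    patches = [
        image[:, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]
        for i in range(grid_length)  # grid row, top to bottom
        for j in range(grid_length)  # grid column, left to right
    ]
    return torch.stack(patches)  # (grid_length**2, C, ph, pw)

image = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(3, 32, 32)
out = grid_by_slicing(image)
print(out.shape)  # torch.Size([4, 3, 16, 16])
```

Because it spells out every index, this version is handy as ground truth when checking a faster implementation.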

My code so far looks like this:

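```python
import torch

def grid_image(image_tensor, grid_length=2):
    n_channels, n_rows, n_cols = image_tensor.shape
    row_length = n_rows // grid_length
    col_length = n_cols // grid_length

    patches = []
    # Split along the row dimension first, then split each strip along columns
    rows = torch.split(image_tensor, row_length, 1)
    for row in rows:
        row_patches = torch.split(row, col_length, 2)
        patches += row_patches

    return torch.stack(patches)  # (grid_length**2, C, row_length, col_length)
```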

I feel like there is a better way to do this, because I am currently splitting the image by rows and then by columns. I'm looking for an approach that does both at the same time, which would presumably be more efficient. I tried searching around but didn't find anything. Could I please have some help?
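
For comparison, another way to split both axes at once (not the approach suggested in this thread) is a single `reshape` followed by `permute`. Both this and the loop are linear in the number of pixels, so any gain is in Python and kernel-launch overhead rather than in complexity. A minimal sketch, assuming the height and width divide evenly by `grid_length`; `grid_image_view` is a hypothetical name.

```python
import torch

def grid_image_view(image: torch.Tensor, grid_length: int = 2) -> torch.Tensor:
    """Hypothetical alternative: split rows and columns in one reshape.
    Assumes H and W are divisible by grid_length (reshape fails otherwise)."""
    n_channels, n_rows, n_cols = image.shape
    ph, pw = n_rows // grid_length, n_cols // grid_length
    # (C, H, W) -> (C, grid_row, ph, grid_col, pw)
    x = image.reshape(n_channels, grid_length, ph, grid_length, pw)
    # -> (grid_row, grid_col, C, ph, pw)
    x = x.permute(1, 3, 0, 2, 4)
    # Merge the two grid axes in row-major order (this copies, since x is non-contiguous)
    return x.reshape(grid_length * grid_length, n_channels, ph, pw)
```

The merged index is `grid_row * grid_length + grid_col`, which gives top-left, top-right, bottom-left, bottom-right for a 2 x 2 grid.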

`tensor.unfold` should work.
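
To illustrate what `Tensor.unfold(dimension, size, step)` does here: it slides a window of length `size` with stride `step` along `dimension` and moves each window into a new last dimension. The tiny 1 x 4 x 4 tensor below is an illustrative assumption chosen so the patches are easy to read.

```python
import torch

# Tiny 1 x 4 x 4 "image" with values 0..15
x = torch.arange(16).reshape(1, 4, 4)

# Non-overlapping windows of 2 rows: (C, grid_row, W, row_in_patch)
rows = x.unfold(1, 2, 2)
# Non-overlapping windows of 2 columns: (C, grid_row, grid_col, row_in_patch, col_in_patch)
patches = rows.unfold(2, 2, 2)

print(rows.shape)     # torch.Size([1, 2, 4, 2])
print(patches.shape)  # torch.Size([1, 2, 2, 2, 2])
print(patches[0, 0, 1].tolist())  # top-right patch: [[2, 3], [6, 7]]
```

Setting `step` equal to `size` makes the windows non-overlapping, which is exactly a grid split.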


Thanks for your answer, ptrblck! I managed to get the code below working for a grid of any size (as long as the image height and width divide evenly by `grid_length`).

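```python
patches = (image_tensor
           .unfold(1, row_length, row_length)  # (C, grid_length, W, row_length)
           .unfold(2, col_length, col_length)  # (C, grid_length, grid_length, row_length, col_length)
           .reshape(n_channels, grid_length**2, row_length, col_length)
           .permute(1, 0, 2, 3))               # (grid_length**2, C, row_length, col_length)
```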