So right now I’m trying to split a tensor (of an image) with dimensions 3 x 32 x 32 into a grid, then stack the grid pieces on top of each other to form a 4 x 3 x 16 x 16 tensor. Along dimension 0, element 0 should be the top-left piece of the image, element 1 the top-right piece, element 2 the bottom-left piece, and element 3 the bottom-right piece.
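The layout described above can be stated precisely as a reshape followed by a permute, which handles the rows and columns in one step. This is an illustrative sketch rather than the original code: `grid_image_reshape` is a hypothetical name, and it assumes the height and width are divisible by `grid_length`.

```python
import torch

def grid_image_reshape(image_tensor: torch.Tensor, grid_length: int = 2) -> torch.Tensor:
    # Hypothetical helper: assumes H and W are divisible by grid_length.
    c, h, w = image_tensor.shape
    ph, pw = h // grid_length, w // grid_length
    # C x H x W -> C x grid_row x ph x grid_col x pw
    x = image_tensor.reshape(c, grid_length, ph, grid_length, pw)
    # -> grid_row x grid_col x C x ph x pw
    x = x.permute(1, 3, 0, 2, 4)
    # -> (grid_row * grid_col) x C x ph x pw, in row-major grid order.
    # reshape copies here because the permuted tensor is not contiguous.
    return x.reshape(grid_length * grid_length, c, ph, pw)
```

Patch `k` comes from grid row `k // grid_length` and grid column `k % grid_length`, which gives the top-left, top-right, bottom-left, bottom-right order for `grid_length=2`.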

My code so far looks like this:

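```
import torch

def grid_image(image_tensor, grid_length=2):
    # image_tensor has shape (channels, rows, cols)
    n_channels, n_rows, n_cols = image_tensor.shape
    row_length = n_rows // grid_length
    col_length = n_cols // grid_length
    patches = []
    # Split into horizontal strips along the row dimension (dim 1)
    rows = torch.split(image_tensor, row_length, 1)
    for row in rows:
        # Split each strip into patches along the column dimension (dim 2)
        row_patches = torch.split(row, col_length, 2)
        patches += row_patches
    # Stack patches in row-major order: (grid_length**2, channels, row_length, col_length)
    return torch.stack(patches)
```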

I feel like there is a better way to do this, because I’m currently splitting the image by rows and then by columns. I’m looking for an approach that does both at the same time, which I assume would make the function more efficient. I searched around but didn’t find anything. Could I please have some help?
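One way to do both splits without a Python loop is to call `Tensor.unfold(dimension, size, step)` once per spatial axis. This is a sketch, not a definitive answer: `grid_image_unfold` is a hypothetical name, and it assumes non-overlapping patches (step equal to size) and dimensions divisible by `grid_length`. Note that `torch.split` and `unfold` both return views, and the final copy (`torch.stack` in the original, `reshape` here) touches every element once either way, so the asymptotic cost is likely the same; any speedup would come from avoiding the Python-level loop.

```python
import torch

def grid_image_unfold(image_tensor: torch.Tensor, grid_length: int = 2) -> torch.Tensor:
    # Hypothetical helper: assumes H and W are divisible by grid_length.
    c, h, w = image_tensor.shape
    ph, pw = h // grid_length, w // grid_length
    # Unfold rows: C x H x W -> C x grid_row x W x ph (view, no copy)
    x = image_tensor.unfold(1, ph, ph)
    # Unfold columns: -> C x grid_row x grid_col x ph x pw (still a view)
    x = x.unfold(2, pw, pw)
    # Move the grid axes to the front: grid_row x grid_col x C x ph x pw
    x = x.permute(1, 2, 0, 3, 4)
    # Flatten the grid axes in row-major order; this is where the copy happens
    return x.reshape(grid_length * grid_length, c, ph, pw)
```

For a 3 x 32 x 32 input with the default `grid_length=2`, this returns a 4 x 3 x 16 x 16 tensor with the same patch order as the loop-based `grid_image`.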