Creating patches from image

I want to split grayscale images into 14x14 patches (4 patches in total per image in the code below). I tried the following code:

import torch

S = 1           # channel dim
W = 28          # width
H = 28          # height
batch_size = 10

x = torch.randn(batch_size, S, W, H)

size = 14       # patch size
stride = 14     # patch stride
patches = x.unfold(2, size, stride).unfold(3, size, stride)

Which returns:
torch.Size([10, 1, 2, 2, 14, 14])

I believe the shape I want is torch.Size([10, 1, 4, 14, 14]), where each of the 4 pieces is one quadrant of the image (in this example at least): four sub-images, each 14x14 pixels. How would I unfold the images to get this result?


You already have the result. The 4 patches are stored as a 2x2 grid in dim2 and dim3, so you can just call .contiguous().view(batch_size, S, -1, size, size) on the result to merge them into a single patch dimension.
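Putting the two steps together, a minimal sketch (using the same names and sizes as in the question) looks like this:

```python
import torch

S = 1           # channel dim
W = 28          # width
H = 28          # height
batch_size = 10

x = torch.randn(batch_size, S, W, H)

size = 14       # patch size
stride = 14     # patch stride

# unfold dim 2 (height) and dim 3 (width) into non-overlapping 14x14 patches
patches = x.unfold(2, size, stride).unfold(3, size, stride)
# shape: torch.Size([10, 1, 2, 2, 14, 14])

# merge the 2x2 grid of patches into a single dimension of 4 patches
patches = patches.contiguous().view(batch_size, S, -1, size, size)
# shape: torch.Size([10, 1, 4, 14, 14])
```

The .contiguous() call is needed because unfold returns a non-contiguous view; the patches then come out in row-major order, so patches[:, :, 0] is the top-left quadrant of each image.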


It just hit me right before I saw your answer. Not sure why I didn't consider it before asking the question. Thanks for the fast reply!

