I want to split grayscale images into 14x14 patches (4 patches per 28x28 image in the code below). I tried the following code:
```python
import torch

S = 1  # channel dim
W = 28  # width
H = 28  # height
batch_size = 10
x = torch.randn(batch_size, S, H, W)  # PyTorch image layout: (N, C, H, W)
size = 14  # patch size
stride = 14  # patch stride
patches = x.unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)
```
This prints:

```
torch.Size([10, 1, 2, 2, 14, 14])
```
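The two extra dimensions in that shape appear to form a 2x2 grid of patch positions: after the chained `unfold` calls, `patches[n, c, i, j]` should be the block `x[n, c, 14*i:14*(i+1), 14*j:14*(j+1)]`. A small check of that indexing, assuming the same sizes as above:

```python
import torch

# Same setup as above: (N, C, H, W) with 28x28 images and 14x14 patches
batch_size, S, H, W = 10, 1, 28, 28
size = stride = 14
x = torch.randn(batch_size, S, H, W)

patches = x.unfold(2, size, stride).unfold(3, size, stride)
# patches has shape (N, C, rows, cols, size, size) = (10, 1, 2, 2, 14, 14)

# Check that grid cell (i, j) equals the matching block of the original image
for i in range(H // stride):
    for j in range(W // stride):
        block = x[:, :, i * stride:i * stride + size, j * stride:j * stride + size]
        assert torch.equal(patches[:, :, i, j], block)
print("all grid cells match")
```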
I believe the shape I want is `torch.Size([10, 1, 4, 14, 14])`. Viewing each of the 4 patches should then show one of the four quadrants of the image (in this example, at least), i.e. four sub-images of 14x14 pixels each. How would I unfold the images to get this result?
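One possible approach, sketched under the assumption that the patches should be ordered row by row (top-left, top-right, bottom-left, bottom-right), is to merge the two grid dimensions with `reshape`. The tensor produced by the chained `unfold` calls is a non-contiguous view, so `reshape` (or `.contiguous().view(...)`) is needed here; a plain `.view(...)` would raise a `RuntimeError`. The `torch.nn.functional.unfold` (im2col) variant is included only as an alternative route to the same tensor.

```python
import torch
import torch.nn.functional as F

batch_size, S, H, W = 10, 1, 28, 28
size = stride = 14
x = torch.randn(batch_size, S, H, W)

# Tensor.unfold twice gives (N, C, rows, cols, size, size);
# reshape merges the rows/cols grid into one patch axis in row-major order.
patches = x.unfold(2, size, stride).unfold(3, size, stride)
patches = patches.reshape(batch_size, S, -1, size, size)
print(patches.shape)  # torch.Size([10, 1, 4, 14, 14])

# Alternative: F.unfold returns (N, C*size*size, L), where L is the number of patches
cols = F.unfold(x, kernel_size=size, stride=stride)          # (10, 196, 4)
patches_f = cols.reshape(batch_size, S, size, size, -1)      # (10, 1, 14, 14, 4)
patches_f = patches_f.permute(0, 1, 4, 2, 3)                 # (10, 1, 4, 14, 14)

# Both routes should produce identical patches
assert torch.equal(patches, patches_f)
# Patch index 1 is the top-right quadrant
assert torch.equal(patches[:, :, 1], x[:, :, :size, size:])
```

With `reshape`, patch `k` corresponds to grid cell `(k // 2, k % 2)` for a 2x2 grid, so the quadrant order follows the order of the original unfolded dimensions.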