Fold/unfold to get patches

Hi All.
I have a dataset of 3328x2560 and 4084x3328 images, and I want to extract patches of size 256x256. How can I use fold and unfold to get the patches and then put them back together into the original image?
Thanks for your help!

Have a look at this post which gives an example of this use case.

Thanks for your reply.
My images are 2-dimensional, and I can't adapt the code to them. I can do the unfold, but I cannot put the patches back together. Could you please help me with that?

Sure!
Could you post the code you are using to unfold your input?
Note that you would have to pad the 4084-pixel dimension to 4096 to get non-overlapping patches that cover the whole image.
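A minimal sketch of that padding, assuming a 2D tensor and zero padding added at the bottom/right edge. The helper name `pad_to_multiple` is made up for illustration and is not part of PyTorch:

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, kh, kw):
    """Zero-pad a 2D tensor at the bottom/right so both dims are multiples of (kh, kw)."""
    h, w = x.shape
    pad_h = (kh - h % kh) % kh  # 0 if h is already a multiple of kh
    pad_w = (kw - w % kw) % kw
    # F.pad lists the pads for the last dim first: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h))

x = torch.randn(4084, 3328)
x_padded = pad_to_multiple(x, 256, 256)
print(x_padded.shape)  # torch.Size([4096, 3328])
```

Other padding modes (e.g. reflection) might suit image data better than zeros, depending on what you do with the patches.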

I just do as you did in the mentioned post, so it’ll be:

x = image
kh, kw = 256, 256 # kernel size
dh, dw = 256, 256 # stride
patches = x.unfold(0, kh, dh).unfold(1, kw, dw)
unfold_shape = patches.size()
patches = patches.contiguous().view(patches.size(0), -1, kh, kw) # unsure
print(patches.shape)

Although I don’t quite understand the line with the “unsure” tag.
And yes, I know about the padding, but is there a way to do it without padding? Maybe by using overlapping patches?

I don’t think there is an easy way of changing the stride for the last window.
Here is the code for both input sizes:

import torch

# 3328x2560
x = torch.randn(3328, 2560)
kh, kw = 256, 256 # kernel size
dh, dw = 256, 256 # stride
patches = x.unfold(0, kh, dh).unfold(1, kw, dw)
unfold_shape = patches.size()

patches = patches.contiguous().view(-1, kh, kw)
print(patches.shape)

# Reshape back
patches_orig = patches.view(unfold_shape)
output_h = unfold_shape[0] * unfold_shape[2]
output_w = unfold_shape[1] * unfold_shape[3]
patches_orig = patches_orig.permute(0, 2, 1, 3).contiguous()
patches_orig = patches_orig.view(output_h, output_w)

# Check for equality
print((patches_orig == x).all())


# 4084x3328
x = torch.randn(4084, 3328)
kh, kw = 256, 256 # kernel size
dh, dw = 256, 256 # stride
patches = x.unfold(0, kh, dh).unfold(1, kw, dw)
unfold_shape = patches.size()

patches = patches.contiguous().view(-1, kh, kw)
print(patches.shape)

# Reshape back
patches_orig = patches.view(unfold_shape)
output_h = unfold_shape[0] * unfold_shape[2] # you will lose some pixels in h
output_w = unfold_shape[1] * unfold_shape[3]
patches_orig = patches_orig.permute(0, 2, 1, 3).contiguous()
patches_orig = patches_orig.view(output_h, output_w)

# Check for equality using slicing, since the reshaped result has fewer rows than x
print((patches_orig == x[:output_h, :output_w]).all())

As you can see, you’ll lose some pixels in dim0 for the second input shape: 4084 is not a multiple of 256, so only 15 * 256 = 3840 of the 4084 rows are covered.
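If losing those rows is not acceptable, one option is to pad before unfolding and crop after reshaping back. This is only a sketch, assuming zero padding at the bottom/right edge and a stride equal to the kernel size:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4084, 3328)
kh, kw = 256, 256  # kernel size (stride is the same, so patches do not overlap)
h, w = x.shape

# Zero-pad bottom/right up to the next multiple of the patch size
x_pad = F.pad(x, (0, (-w) % kw, 0, (-h) % kh))

patches = x_pad.unfold(0, kh, kh).unfold(1, kw, kw)
unfold_shape = patches.size()  # [16, 13, 256, 256]
patches = patches.contiguous().view(-1, kh, kw)

# ... process the patches here ...

# Reshape back to the padded image, then crop to the original size
restored = patches.view(unfold_shape).permute(0, 2, 1, 3).contiguous()
restored = restored.view(unfold_shape[0] * kh, unfold_shape[1] * kw)
restored = restored[:h, :w]

print(torch.equal(restored, x))  # True
```

The crop at the end simply discards the padded rows, so the restored tensor has the original 4084x3328 shape.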
