Fold/unfold to get patches

I don’t think there is an easy way of changing the stride for the last window.
Here is the code for both input sizes:

import torch

# 3328x2560
x = torch.randn(3328, 2560)
kh, kw = 256, 256 # kernel size
dh, dw = 256, 256 # stride
patches = x.unfold(0, kh, dh).unfold(1, kw, dw)
unfold_shape = patches.size()

patches = patches.contiguous().view(-1, kh, kw)
print(patches.shape)

# Reshape back
patches_orig = patches.view(unfold_shape)
output_h = unfold_shape[0] * unfold_shape[2]
output_w = unfold_shape[1] * unfold_shape[3]
patches_orig = patches_orig.permute(0, 2, 1, 3).contiguous()
patches_orig = patches_orig.view(output_h, output_w)

# Check for equality
print((patches_orig == x).all())


# 4084x3328
x = torch.randn(4084, 3328)
kh, kw = 256, 256 # kernel size
dh, dw = 256, 256 # stride
patches = x.unfold(0, kh, dh).unfold(1, kw, dw)
unfold_shape = patches.size()

patches = patches.contiguous().view(-1, kh, kw)
print(patches.shape)

# Reshape back
patches_orig = patches.view(unfold_shape)
output_h = unfold_shape[0] * unfold_shape[2] # 15 * 256 = 3840 < 4084, so you will lose some pixels in h
output_w = unfold_shape[1] * unfold_shape[3] # 13 * 256 = 3328, nothing lost in w
patches_orig = patches_orig.permute(0, 2, 1, 3).contiguous()
patches_orig = patches_orig.view(output_h, output_w)

# Check for equality using slicing, since the reconstruction is smaller than x in h (dim0)
print((patches_orig == x[:output_h, :output_w]).all())

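As you can see, you’ll lose some pixels in dim0 for the second input shape: 4084 is not a multiple of 256, so only 15 full windows fit and the last 244 rows are dropped. dim1 (3328 = 13 * 256) is covered completely.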
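
If dropping those pixels is not acceptable, one possible workaround (not part of the code above, just a sketch) is to zero-pad the input up to the next multiple of the kernel size before unfolding, then crop the reconstruction back to the original shape. The helper name `pad_to_multiple` is made up for this example, and zero padding on the bottom/right is an assumption; another padding mode might suit your data better.

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, kh, kw):
    """Hypothetical helper: zero-pad a 2D tensor on the bottom/right so that
    both dims become multiples of the kernel size."""
    h, w = x.shape
    pad_h = (kh - h % kh) % kh  # 0 if h is already a multiple of kh
    pad_w = (kw - w % kw) % kw
    # F.pad expects pads for the last dim first: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h))

x = torch.randn(4084, 3328)
kh, kw = 256, 256  # kernel size == stride (non-overlapping patches)

x_pad = pad_to_multiple(x, kh, kw)
patches = x_pad.unfold(0, kh, kh).unfold(1, kw, kw)
unfold_shape = patches.size()
patches = patches.contiguous().view(-1, kh, kw)
print(patches.shape)

# Reshape back to the padded image, then crop away the padding
recon = patches.view(unfold_shape).permute(0, 2, 1, 3).contiguous()
recon = recon.view(x_pad.shape)[:x.shape[0], :x.shape[1]]

# Check for equality against the full, uncropped input
print(torch.equal(recon, x))
```

The trade-off is that the patches in the last row of windows contain padded zeros, so whatever consumes the patches has to tolerate that.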
