How to extract smaller image patches (3D)?

For a patch size of 64x64x64 and a stride of 46 in each dimension, this code should work:

import torch
import torch.nn.functional as F

x = torch.randn(143, 143, 284)

kernel_size = 64
stride = 46

# Calculate padding to fit the sliding windows
pad0_left = (x.size(0) // stride * stride + kernel_size) - x.size(0)
pad1_left = (x.size(1) // stride * stride + kernel_size) - x.size(1)
pad2_left = (x.size(2) // stride * stride + kernel_size) - x.size(2)

# Split the padding as evenly as possible; for odd totals the extra element goes to the right
pad0_right = pad0_left // 2 if pad0_left % 2 == 0 else pad0_left // 2 + 1
pad1_right = pad1_left // 2 if pad1_left % 2 == 0 else pad1_left // 2 + 1
pad2_right = pad2_left // 2 if pad2_left % 2 == 0 else pad2_left // 2 + 1

pad0_left = pad0_left // 2
pad1_left = pad1_left // 2
pad2_left = pad2_left // 2

# F.pad expects the padding for the last dimension first
x = F.pad(x, (pad2_left, pad2_right, pad1_left, pad1_right, pad0_left, pad0_right))
print(x.shape)  # torch.Size([202, 202, 340])

# Each unfold appends one window dimension, so ret has shape [4, 4, 7, 64, 64, 64]
ret = x.unfold(0, kernel_size, stride).unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)

I would recommend verifying it by visualizing the patches, as I have only checked it on random input.
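Besides visualizing, a quick numerical check is possible. The following is only a sketch: it repeats the padding logic from above in a small helper (`pad_amounts` is a name made up for this example), fills the tensor with unique values via `torch.arange`, and compares every unfolded patch against the same window taken by plain slicing.

```python
import itertools

import torch
import torch.nn.functional as F

kernel_size, stride = 64, 46

# Unique values per voxel, so a misplaced patch cannot match by accident
x = torch.arange(143 * 143 * 284, dtype=torch.float32).view(143, 143, 284)

def pad_amounts(size):
    # Hypothetical helper: same total padding as above, split into (left, right)
    total = (size // stride * stride + kernel_size) - size
    left = total // 2
    return left, total - left

p0, p1, p2 = (pad_amounts(s) for s in x.shape)
x_pad = F.pad(x, (*p2, *p1, *p0))  # last dimension first

patches = (
    x_pad.unfold(0, kernel_size, stride)
    .unfold(1, kernel_size, stride)
    .unfold(2, kernel_size, stride)
)

# Compare each patch with the window obtained by direct slicing
n0, n1, n2 = patches.shape[:3]
for i, j, k in itertools.product(range(n0), range(n1), range(n2)):
    s0, s1, s2 = i * stride, j * stride, k * stride
    expected = x_pad[
        s0:s0 + kernel_size, s1:s1 + kernel_size, s2:s2 + kernel_size
    ]
    assert torch.equal(patches[i, j, k], expected)

print(patches.shape)
```

If all assertions pass, `patches[i, j, k]` is the window starting at `(i * stride, j * stride, k * stride)` in the padded volume. This checks indexing only; visualizing real data is still worthwhile to judge whether the zero padding at the borders is acceptable.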
