Problem:
Fold/unfold overlapping patches from 3D tensors.
Context:
I am working on a segmentation problem where I make a prediction for each patch, and then use mean pooling to combine predictions from overlapping patches.
Using code from here, non-overlapping patches work great, but I have been unable to adapt the code to overlapping patches.
import torch
# Unfold data
x = torch.randn(1, 256, 256, 256) # batch, c, h, w
kc, kh, kw = 64, 64, 64 # kernel size
dc, dh, dw = 64, 64, 64 # stride
patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
unfold_shape = patches.size()
patches = patches.contiguous().view(patches.size(0), -1, kc, kh, kw)
# Reshape back
patches_orig = patches.view(unfold_shape)
output_c = unfold_shape[1] * unfold_shape[4]
output_h = unfold_shape[2] * unfold_shape[5]
output_w = unfold_shape[3] * unfold_shape[6]
patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
patches_orig = patches_orig.view(unfold_shape[0], output_c, output_h, output_w)  # keep the batch size instead of hardcoding 1
# Check for equality
print((patches_orig == x).all())
Is it possible to use a stride of 16 (dc, dh, dw = 16, 16, 16) with the same kernel size and average the predictions wherever patches overlap when folding back?
Cheers,
Brendan
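
One way this might be approached, as a minimal sketch rather than a definitive answer: with overlapping patches the view/permute trick above no longer applies, because each voxel belongs to several patches. Instead, each patch can be added back at its original location while a second tensor counts how many patches covered each voxel; dividing the sum by the count gives the mean. The helper name fold_mean is hypothetical (it is not a PyTorch function), and the sketch assumes the patches have the 7D layout produced by the three chained unfold calls in the question and that (size - kernel) is divisible by the stride in every dimension, so every voxel is covered at least once. Small sizes are used so it runs quickly.

```python
import itertools
import torch

def fold_mean(patches, out_shape, kernel, stride):
    """Average overlapping patches back into a volume.

    patches: (B, nc, nh, nw, kc, kh, kw), the layout produced by
             x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
    out_shape: shape of the original tensor, (B, C, H, W)
    """
    kc, kh, kw = kernel
    dc, dh, dw = stride
    nc, nh, nw = patches.shape[1:4]
    total = torch.zeros(out_shape, dtype=patches.dtype, device=patches.device)
    count = torch.zeros(out_shape, dtype=patches.dtype, device=patches.device)
    for i, j, k in itertools.product(range(nc), range(nh), range(nw)):
        c, h, w = i * dc, j * dh, k * dw  # corner of this patch in the volume
        total[:, c:c + kc, h:h + kh, w:w + kw] += patches[:, i, j, k]
        count[:, c:c + kc, h:h + kh, w:w + kw] += 1
    # Voxels never covered by a patch (count == 0) are left at 0.
    return total / count.clamp(min=1)

# Small example: 32^3 volume, 16^3 patches, stride 4 (assumed sizes)
x = torch.randn(1, 32, 32, 32)
kernel, stride = (16, 16, 16), (4, 4, 4)
patches = x.unfold(1, 16, 4).unfold(2, 16, 4).unfold(3, 16, 4)
recon = fold_mean(patches, x.shape, kernel, stride)
print(torch.allclose(recon, x, atol=1e-5))
```

In a real pipeline the per-patch predictions would take the place of patches before folding. The explicit loop is slower than a vectorized version but keeps the indexing easy to check; for kernel 64 and stride 16 on a 256^3 volume it would run over 13 x 13 x 13 patch positions.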