Fold and Unfold: How do I put this image tensor back together again?

I solved the issue and was able to create a 2D sliding window which seamlessly stitched the image back together in the end.

I followed ptrblck’s response on this post and adapted it to my tensor. I reshaped my image from [256, 256] to the standard B x C x H x W layout (1 x 1 x 256 x 256), then unfolded it using my desired kernel size and stride:

# CREATE THE UNFOLDED IMAGE SLICES
import torch
import torch.nn.functional as F  # needed for F.fold further down

I = image             # shape [256, 256]
kernel_size = bx      # patch size, 16 in my case
stride = int(bx/2)    # half the patch size (50% overlap), 8 in my case
I2 = I.unsqueeze(0).unsqueeze(0)  # shape [1, 1, 256, 256]
patches2 = I2.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
# shape [1, 1, 31, 31, 16, 16]
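The 31 x 31 grid of patches follows from the usual sliding-window count, (size - kernel_size) // stride + 1. Here is a small sketch to check it, using a random tensor as a stand-in for the real image; `num_windows` is a helper name made up for this example:

```python
import torch

def num_windows(size: int, kernel_size: int, stride: int) -> int:
    """Number of sliding-window positions along one dimension (no padding)."""
    return (size - kernel_size) // stride + 1

image = torch.rand(256, 256)   # stand-in for the real image
kernel_size, stride = 16, 8

patches = (
    image.unsqueeze(0).unsqueeze(0)      # [1, 1, 256, 256]
    .unfold(2, kernel_size, stride)      # slide over H
    .unfold(3, kernel_size, stride)      # slide over W
)

n = num_windows(256, kernel_size, stride)
print(n, tuple(patches.shape))  # 31 (1, 1, 31, 31, 16, 16)
```

Patch (i, j) is simply the crop image[i*stride : i*stride + kernel_size, j*stride : j*stride + kernel_size], which is handy to know when debugging the layout.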

Next, I applied my transforms and filters. In my case, I used cosine windows and normalised my data blocks so that the overlapping sections summed correctly when the tensor was folded back together. It would be simpler to use a mask for this; there are some good examples on this forum and on StackOverflow. For completeness' sake, here is the cosine windowing and normalising I implemented:

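# NORMALISE AND WINDOW
# win is the 2D cosine window (same size as a patch), noise_std the noise standard deviation
Pvv = torch.mean(torch.pow(win, 2))*torch.numel(win)*(noise_std**2)  # window energy times noise variance
Pvv = Pvv.double()
mean_patches = torch.mean(patches2, (4, 5), keepdim=True)  # per-patch mean, shape [1, 1, 31, 31, 1, 1]
mean_patches = mean_patches.repeat(1, 1, 1, 1, 16, 16)     # expanded to [1, 1, 31, 31, 16, 16]
window_patches = win.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 31, 31, 1, 1)  # one window per patch
zero_mean = patches2 - mean_patches
windowed_patches = zero_mean * window_patches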

#SOME FILTERING ....

#ADD MEAN AND WINDOW BEFORE FOLDING BACK TOGETHER.
filt_data_block = (filt_data_block + mean_patches*window_patches) * window_patches
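The window itself is not shown above, so here is a rough sketch with a hypothetical stand-in: a separable sine window, which is an assumption and not necessarily the `win` used in this post. Because the window is applied twice (once before filtering, once before folding), each pixel is effectively weighted by the squared window, and with 50% overlap the squared sine windows sum to one away from the image border:

```python
import math
import torch
import torch.nn.functional as F

kernel_size, stride = 16, 8
grid = (256 - kernel_size) // stride + 1          # 31 patches per dimension

# Hypothetical window: 1D sine window, made 2D by an outer product.
n = torch.arange(kernel_size, dtype=torch.float64)
w1d = torch.sin(math.pi * (n + 0.5) / kernel_size)
win = w1d[:, None] * w1d[None, :]                 # [16, 16]

# Fold the squared window from every patch position to see the total
# weight each output pixel receives after windowing twice.
weights = (win ** 2).reshape(1, -1, 1).repeat(1, 1, grid * grid)  # [1, 256, 961]
coverage = F.fold(weights, output_size=(256, 256),
                  kernel_size=kernel_size, stride=stride)

interior = coverage[0, 0, stride:-stride, stride:-stride]
print(interior.min().item(), interior.max().item())  # both approximately 1.0
```

Under this assumed window, the outer `stride` pixels on each side are covered by fewer patches and get a total weight below one, so a border would need padding or a separate normalisation if it matters for your data.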

Next, I reshaped my tensor into the format that the fold() function expects, and applied the function:

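# REASSEMBLE THE IMAGE USING FOLD
# fold() expects [N, C*kernel_size*kernel_size, L], with L the number of patches
patches = filt_data_block.contiguous().view(1, 1, -1, kernel_size*kernel_size)  # [1, 1, 961, 256]
patches = patches.permute(0, 1, 3, 2)                                            # [1, 1, 256, 961]
patches = patches.contiguous().view(1, kernel_size*kernel_size, -1)              # [1, 256, 961]
IR = F.fold(patches, output_size=(256, 256), kernel_size=kernel_size, stride=stride)  # overlaps are summed
IR = IR.squeeze()  # back to [256, 256]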
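One way to convince yourself that the view/permute sequence is right is to compare it against `F.unfold`, which produces the column layout that `F.fold` consumes. A minimal sketch, with a random tensor standing in for the image:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 256, 256)   # stand-in for the real image
kernel_size, stride = 16, 8

# Same steps as in the post: Tensor.unfold twice, then view/permute/view.
patches = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
cols = patches.contiguous().view(1, 1, -1, kernel_size * kernel_size)
cols = cols.permute(0, 1, 3, 2)
cols = cols.contiguous().view(1, kernel_size * kernel_size, -1)

# Reference layout straight from F.unfold: [N, C*k*k, L]
reference = F.unfold(x, kernel_size=kernel_size, stride=stride)

print(cols.shape, torch.equal(cols, reference))
```

The two tensors should match element for element, so the reshape hands `F.fold` exactly the layout it expects.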

With this, the output image came back together with no artifacts and in the correct size!
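For anyone who wants the simpler mask-based route mentioned earlier instead of windowing, here is a minimal sketch. It is not what I used above: it folds a tensor of ones to count how many patches cover each pixel, then divides the folded sum by that count. A random tensor stands in for the image and no filtering is applied, so the result should reproduce the input:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 256, 256)   # stand-in for the real image
kernel_size, stride = 16, 8
fold_args = dict(output_size=(256, 256), kernel_size=kernel_size, stride=stride)

# Unfold into columns, (filter here if needed), then fold: overlaps are summed.
cols = F.unfold(x, kernel_size=kernel_size, stride=stride)     # [1, 256, 961]
summed = F.fold(cols, **fold_args)                             # [1, 1, 256, 256]

# Overlap-count mask: how many patches contributed to each pixel.
ones = torch.ones_like(x)
counts = F.fold(F.unfold(ones, kernel_size=kernel_size, stride=stride), **fold_args)

restored = summed / counts
print(counts.min().item(), counts.max().item())   # 1.0 at the corners, 4.0 in the interior
print(torch.allclose(restored, x))
```

This only works as-is when every pixel is covered by at least one patch, which holds here because (256 - 16) is a multiple of the stride.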
