Fold and Unfold: How do I put this image tensor back together again?

I am trying to filter a single channel 2D image of size 256x256 using unfold to create 16x16 blocks with an overlap of 8. This is shown below:

    # I = [256, 256] image
    kernel_size = 16
    stride = kernel_size // 2  # 8
    patches = I.unfold(1, kernel_size, stride).unfold(0, kernel_size, stride)  # [31, 31, 16, 16]
    # Begin filtering...

I have started trying to put the image back together with fold, but I’m not quite there yet. I’ve tried using view to get the image to ‘fit’ the way it’s supposed to, but I don’t see how this would preserve the original image. Perhaps I’m overthinking this.

    # filtered patches: filt_data_block.shape = [31, 31, 16, 16]
    patches = filt_data_block.contiguous().view(-1, kernel_size*kernel_size)  # [961, 256]
    patches = patches.permute(1, 0)  # [256, 961]

Any help would be greatly appreciated. Thanks very much.

You could take inspiration from the torch.nn.Fold example.
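As a rough sketch of that idea (not the exact example from the docs): `F.unfold` already produces patches in the layout `F.fold` expects, and because `fold` sums overlapping pixels, dividing by the folded output of an all-ones tensor undoes the overlap. The random `image` below is a stand-in for the real data.

```python
import torch
import torch.nn.functional as F

kernel_size, stride = 16, 8
image = torch.rand(1, 1, 256, 256)  # stand-in image, B x C x H x W

# F.unfold returns patches in the layout F.fold expects: [B, C*k*k, L]
cols = F.unfold(image, kernel_size=kernel_size, stride=stride)  # [1, 256, 961]

# fold sums overlapping pixels, so each pixel is scaled by its overlap count
summed = F.fold(cols, output_size=(256, 256),
                kernel_size=kernel_size, stride=stride)

# Push a tensor of ones through the same path to get that per-pixel count
ones = torch.ones_like(image)
divisor = F.fold(F.unfold(ones, kernel_size=kernel_size, stride=stride),
                 output_size=(256, 256),
                 kernel_size=kernel_size, stride=stride)

reconstructed = summed / divisor  # equals image when nothing was filtered
```

If the patches are filtered between `unfold` and `fold`, the same division averages the overlapping estimates at each pixel.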

I solved the issue and was able to create a 2D sliding window which seamlessly stitched the image back together in the end.

I used ptrblck’s response on this post and adapted my tensor to work. I changed my image from [256, 256] to standard shape B x C x H x W (1 x 1 x 256 x 256). Then I unfolded using my desired stride and kernel dimensions:

    # CREATE THE UNFOLDED IMAGE SLICES
    I = image             # shape [256, 256]
    kernel_size = bx      # block size, 16
    stride = int(bx / 2)  # half a block, 8
    I2 = I.unsqueeze(0).unsqueeze(0)  # shape [1, 1, 256, 256]
    patches2 = I2.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    # shape [1, 1, 31, 31, 16, 16]
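To make the resulting layout concrete, here is a small check (a sketch using a stand-in `image`, with arbitrarily chosen patch indices `i` and `j`) that patch `[0, 0, i, j]` is the 16x16 block whose top-left corner sits at row `i * stride`, column `j * stride`:

```python
import torch

kernel_size, stride = 16, 8
# Stand-in image whose pixel values encode their own position
image = torch.arange(256 * 256, dtype=torch.float32).reshape(256, 256)

patches = (image.unsqueeze(0).unsqueeze(0)
           .unfold(2, kernel_size, stride)
           .unfold(3, kernel_size, stride))  # [1, 1, 31, 31, 16, 16]

i, j = 3, 5  # arbitrary patch indices
top, left = i * stride, j * stride
expected = image[top:top + kernel_size, left:left + kernel_size]
same = torch.equal(patches[0, 0, i, j], expected)

# Number of patches per axis: (256 - 16) // 8 + 1
n_per_axis = (256 - kernel_size) // stride + 1
```

Unfolding dimension 2 before dimension 3 is what keeps the last two dimensions in (row, column) order inside each patch.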

Next, I applied my transforms and filters. Note that in my case, I used cosine windows and normalised my data blocks so that when the tensor was folded back together, the overlapping sections summed back together correctly. It would be simpler to use a mask for this; there are some good examples on this forum and on StackOverflow. For completeness, here is the cosine windowing and normalising I implemented:

    # NORMALISE AND WINDOW
    # win is the [16, 16] cosine window and noise_std the noise level (both defined elsewhere)
    Pvv = torch.mean(torch.pow(win, 2))*torch.numel(win)*(noise_std**2)
    Pvv = Pvv.double()
    # per-patch mean, repeated to the full patch shape [1, 1, 31, 31, 16, 16]
    mean_patches = torch.mean(patches2, (4, 5), keepdim=True)
    mean_patches = mean_patches.repeat(1, 1, 1, 1, 16, 16)
    # window repeated for every patch, also [1, 1, 31, 31, 16, 16]
    window_patches = win.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 31, 31, 1, 1)
    zero_mean = patches2 - mean_patches
    windowed_patches = zero_mean * window_patches

    # SOME FILTERING ....

    # ADD MEAN AND WINDOW BEFORE FOLDING BACK TOGETHER.
    filt_data_block = (filt_data_block + mean_patches*window_patches) * window_patches
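For comparison, the mask approach can be sketched as follows. This is not the code I used: the separable sine window, the random `image`, the identity "filter" and the helper `to_columns` are all assumptions for illustration, and the mean handling is left out. The window is applied twice, the patches are folded, and the result is divided by the folded squared window.

```python
import math
import torch
import torch.nn.functional as F

kernel_size, stride = 16, 8
image = torch.rand(1, 1, 256, 256, dtype=torch.float64)  # stand-in image

# Hypothetical separable sine window; strictly positive, so no zero weights
n = torch.arange(kernel_size, dtype=torch.float64)
win_1d = torch.sin(math.pi * (n + 0.5) / kernel_size)
win = torch.outer(win_1d, win_1d)  # [16, 16]

patches = image.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
windowed = patches * win   # analysis window, broadcast over all patches
filtered = windowed        # placeholder for the real filtering step
weighted = filtered * win  # synthesis window


def to_columns(p):
    # [1, 1, 31, 31, 16, 16] -> [1, 256, 961], the layout F.fold expects
    k2 = kernel_size * kernel_size
    return p.reshape(1, -1, k2).permute(0, 2, 1).contiguous()


fold_args = dict(output_size=(256, 256), kernel_size=kernel_size, stride=stride)
summed = F.fold(to_columns(weighted), **fold_args)
# Mask: the accumulated squared window at every pixel
weights = F.fold(to_columns((win ** 2).expand_as(patches)), **fold_args)
restored = summed / weights
```

With this particular window and a 50% overlap, the weights are 1 wherever a pixel is covered in both directions by two patches; only the 8-pixel border strips get less, which is where the division matters.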

Next, I reshaped my tensor so that it would be in the correct format for the fold() function, and applied the function:

    # REASSEMBLE THE IMAGE USING FOLD
    import torch.nn.functional as F

    patches = filt_data_block.contiguous().view(1, 1, -1, kernel_size*kernel_size)  # [1, 1, 961, 256]
    patches = patches.permute(0, 1, 3, 2)                                           # [1, 1, 256, 961]
    patches = patches.contiguous().view(1, kernel_size*kernel_size, -1)             # [1, 256, 961]
    IR = F.fold(patches, output_size=(256, 256), kernel_size=kernel_size, stride=stride)
    IR = IR.squeeze()  # back to [256, 256]

With this, the output image came back together with no artifacts and in the correct size!
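One way to sanity-check the reshape (a sketch with a random stand-in `image` and no filtering) is to compare it against `F.unfold`, which returns the `[B, C*k*k, L]` layout that `F.fold` consumes:

```python
import torch
import torch.nn.functional as F

kernel_size, stride = 16, 8
image = torch.rand(1, 1, 256, 256)  # stand-in image

patches = image.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)

# Same view / permute / view sequence as above
cols = patches.contiguous().view(1, 1, -1, kernel_size * kernel_size)
cols = cols.permute(0, 1, 3, 2)
cols = cols.contiguous().view(1, kernel_size * kernel_size, -1)

# Reference layout produced directly by F.unfold
reference = F.unfold(image, kernel_size=kernel_size, stride=stride)
matches = torch.equal(cols, reference)
```

If `matches` is true, the manually reshaped patches are in exactly the order `F.fold` expects, so any remaining artifacts would have to come from the filtering or windowing rather than from the reshape.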