Fold and Unfold: How do I put this image tensor back together again?

I am trying to filter a single channel 2D image of size 256x256 using unfold to create 16x16 blocks with an overlap of 8. This is shown below:

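    # I: [256, 256] image tensor
    kernel_size = 16
    stride = kernel_size // 2  # 8, i.e. an overlap of 8 pixels
    # Unfold rows (dim 0) first, then columns (dim 1), so each patch keeps its [row, col] orientation
    patches = I.unfold(0, kernel_size, stride).unfold(1, kernel_size, stride)  # [31, 31, 16, 16]
    # Begin filtering...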
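A note on the order of the two unfold calls: Tensor.unfold appends the new window dimension at the end. If columns are unfolded first and rows second, each patch comes out as [col, row], i.e. transposed relative to the image. This minimal sketch uses a stand-in arange image (an assumption, not the poster's data) to check which order gives upright patches.

```python
import torch

I = torch.arange(256 * 256, dtype=torch.float32).reshape(256, 256)  # stand-in image
kernel_size, stride = 16, 8

rows_first = I.unfold(0, kernel_size, stride).unfold(1, kernel_size, stride)  # [31, 31, 16, 16]
cols_first = I.unfold(1, kernel_size, stride).unfold(0, kernel_size, stride)  # [31, 31, 16, 16]

block = I[8:24, 16:32]  # patch (1, 2): rows 8..23, cols 16..31 of the image
print(torch.equal(rows_first[1, 2], block))    # True: patch matches the image slice
print(torch.equal(cols_first[1, 2], block.T))  # True: this order yields a transposed patch
```

Both orders give the same [31, 31, 16, 16] shape, so the transposition is easy to miss until the patches are folded back.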

I have started trying to put the image back together with fold, but I’m not quite there yet. I’ve tried to use view to get the tensor to ‘fit’ the shape fold expects, but I don’t see how this would preserve the original image. Perhaps I’m overthinking this.

    # filt_data_block.shape = [31, 31, 16, 16]  (the filtered patches)
    patches = filt_data_block.contiguous().view(-1, kernel_size*kernel_size)  # [961, 256]
    patches = patches.permute(1, 0)  # [256, 961]

Any help would be greatly appreciated. Thanks very much.

You could take inspiration from the torch.nn.Fold example.
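A minimal round trip with the module versions, nn.Unfold and nn.Fold, using a random [1, 1, 256, 256] tensor as a stand-in image. Fold sums overlapping values rather than averaging them, so dividing by a fold of an all-ones input (the number of patches covering each pixel) recovers the input. This follows the divisor idea described in the torch.nn.Fold documentation.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 1, 256, 256)  # stand-in image, [B, C, H, W]
unfold = nn.Unfold(kernel_size=16, stride=8)
fold = nn.Fold(output_size=(256, 256), kernel_size=16, stride=8)

cols = unfold(x)                             # [1, 1*16*16, 31*31] = [1, 256, 961]
summed = fold(cols)                          # overlapping pixels are summed, not averaged
divisor = fold(unfold(torch.ones_like(x)))   # number of patches covering each pixel
recon = summed / divisor

print(cols.shape)                # torch.Size([1, 256, 961])
print(torch.allclose(recon, x))  # True
```

With a 16x16 kernel and stride 8, interior pixels are covered by 4 patches and the corners by 1, which is why a plain fold without the division looks too bright in the overlaps.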

I solved the issue and was able to create a 2D sliding window which seamlessly stitched the image back together in the end.

I used ptrblck’s response on this post and adapted my tensor to work. I changed my image from [256, 256] to standard shape B x C x H x W (1 x 1 x 256 x 256). Then I unfolded using my desired stride and kernel dimensions:

    import torch
    import torch.nn.functional as F

    # CREATE THE UNFOLDED IMAGE SLICES
    bx = 16                            # block size
    I = image                          # [256, 256] input image tensor
    kernel_size = bx                   # 16
    stride = bx // 2                   # 8, i.e. 50% overlap
    I2 = I.unsqueeze(0).unsqueeze(0)   # [1, 1, 256, 256] (B x C x H x W)
    patches2 = I2.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    # patches2: [1, 1, 31, 31, 16, 16]

Next, I applied my transforms and filters. Note that in my case I used cosine windows and normalised my data blocks so that, when the tensor was folded back together, the overlapping sections summed back to the correct values. It would be simpler to use a mask for this; there are some good examples on this forum and on StackOverflow. For completeness' sake, here is the cosine windowing and normalising I implemented:

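    # NORMALISE AND WINDOW
    # win: [16, 16] cosine window; noise_std: noise standard deviation (both defined elsewhere)
    Pvv = torch.mean(torch.pow(win, 2)) * torch.numel(win) * (noise_std ** 2)  # noise power per windowed block
    Pvv = Pvv.double()
    mean_patches = torch.mean(patches2, (4, 5), keepdim=True)  # [1, 1, 31, 31, 1, 1]
    mean_patches = mean_patches.repeat(1, 1, 1, 1, 16, 16)     # [1, 1, 31, 31, 16, 16]
    window_patches = win.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 31, 31, 1, 1)
    zero_mean = patches2 - mean_patches                        # remove each block's mean
    windowed_patches = zero_mean * window_patches              # first application of the window

    # SOME FILTERING ... (produces filt_data_block from windowed_patches)

    # ADD MEAN AND WINDOW BEFORE FOLDING BACK TOGETHER (second application of the window)
    filt_data_block = (filt_data_block + mean_patches * window_patches) * window_patches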
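With the filtering step left out, the code above weights every patch by win**2: once on the zero-mean block, and once more after the mean is added back. This minimal sketch assumes win is a separable sine window, w[k] = sin(pi * (k + 0.5) / 16); the post does not say exactly which cosine window was used. It checks that the folded weights sum to 1 wherever two windows overlap in each direction, because offsets 8 apart give sin²θ + cos²θ = 1.

```python
import math
import torch
import torch.nn.functional as F

def sine_window(n: int) -> torch.Tensor:
    # Assumed "cosine window": w[k] = sin(pi * (k + 0.5) / n), made 2D by an outer product.
    k = torch.arange(n, dtype=torch.float64)
    w1d = torch.sin(math.pi * (k + 0.5) / n)
    return torch.outer(w1d, w1d)  # [n, n]

kernel_size, stride, size = 16, 8, 256
win = sine_window(kernel_size)
n_blocks = (size - kernel_size) // stride + 1  # 31 patches per spatial dimension

# Every patch carries the weight win**2 (window applied twice); lay it out as [B, k*k, L] for fold.
weights = (win ** 2).reshape(1, kernel_size * kernel_size, 1).repeat(1, 1, n_blocks * n_blocks)
coverage = F.fold(weights, output_size=(size, size), kernel_size=kernel_size, stride=stride)[0, 0]

interior = coverage[stride:size - stride, stride:size - stride]
print(torch.allclose(interior, torch.ones_like(interior)))  # True
```

With this particular window, the outermost 8 pixels on each side are covered by only one patch, so their summed weight is below 1. A mask or a per-pixel count handles those borders.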

Next, I reshaped my tensor into the [B, C × kernel_size × kernel_size, L] layout that the fold() function expects and applied the function:

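    # REASSEMBLE THE IMAGE USING FOLD
    # fold expects [B, C * kernel_size * kernel_size, L], where L = 31 * 31 = 961 patches
    patches = filt_data_block.contiguous().view(1, 1, -1, kernel_size*kernel_size)  # [1, 1, 961, 256]
    patches = patches.permute(0, 1, 3, 2)                                           # [1, 1, 256, 961]
    patches = patches.contiguous().view(1, kernel_size*kernel_size, -1)             # [1, 256, 961]
    IR = F.fold(patches, output_size=(256, 256), kernel_size=kernel_size, stride=stride)  # [1, 1, 256, 256]
    IR = IR.squeeze()  # [256, 256]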

With this, the output image came back together with no artifacts and in the correct size!
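For a single-channel image, the view/permute/view sequence above produces exactly the layout that F.unfold returns directly. This minimal sketch uses a random stand-in for I2 and compares the two routes.

```python
import torch
import torch.nn.functional as F

kernel_size, stride = 16, 8
I2 = torch.rand(1, 1, 256, 256)  # stand-in for the [1, 1, 256, 256] image

# Route from the post: Tensor.unfold twice, then view / permute / view.
patches2 = I2.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)  # [1, 1, 31, 31, 16, 16]
manual = patches2.contiguous().view(1, 1, -1, kernel_size * kernel_size)     # [1, 1, 961, 256]
manual = manual.permute(0, 1, 3, 2).contiguous().view(1, kernel_size * kernel_size, -1)  # [1, 256, 961]

# F.unfold already returns the [B, C*k*k, L] layout that F.fold consumes.
direct = F.unfold(I2, kernel_size=kernel_size, stride=stride)  # [1, 256, 961]
print(torch.equal(manual, direct))  # True
```

The manual route hard-codes one channel in its view calls. F.unfold also handles C > 1 without any reshaping.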


Thanks for adding the comments; they really helped with reproducing this.

For anyone else who comes here struggling with unfolding and reconstructing an image: I created a runnable code gist based on @Bled_Clement's excellent solution. Just execute it in a notebook. As you can see, this also works for batches of images.

    # %pip install torch torchvision numpy pillow scikit-image

    import numpy as np
    import torch
    import torch.nn.functional as F
    from PIL import Image
    from skimage.data import astronaut, coffee

    # Convert both RGB sample images to grayscale and crop them to 250x250.
    gray_astronaut = np.uint8(np.average(astronaut(), axis=2))
    astro_t = torch.Tensor(gray_astronaut).unsqueeze(0)[:, :250, :250]

    gray_coffee = np.uint8(np.average(coffee(), axis=2))
    coffee_t = torch.Tensor(gray_coffee).unsqueeze(0)[:, :250, :250]

    t = torch.stack([coffee_t, astro_t], dim=0)
    print(f't shape = {t.shape}')
    # t shape = torch.Size([2, 1, 250, 250])

    kernel_size = 200
    stride = 50

    unfolded = t.unfold(2, size=kernel_size, step=stride).unfold(3, size=kernel_size, step=stride)
    print(f'unfolded shape = {unfolded.shape}')
    # unfolded shape = torch.Size([2, 1, 2, 2, 200, 200])

    # Reshape and permute into the [B, C*k*k, L] layout that F.fold expects.
    view = unfolded.contiguous().view(t.shape[0], 1, -1, kernel_size * kernel_size)
    print(f'view shape = {view.shape}')
    # view shape = torch.Size([2, 1, 4, 40000])

    permute = view.permute(0, 1, 3, 2)
    print(f'permute shape = {permute.shape}')
    # permute shape = torch.Size([2, 1, 40000, 4])

    patches = permute.contiguous().view(t.shape[0], kernel_size * kernel_size, -1)
    print(f'patches shape = {patches.shape}')
    # patches shape = torch.Size([2, 40000, 4])

    # fold sums overlapping patches; folding a tensor of ones counts the overlaps per pixel.
    folded = F.fold(patches, output_size=(250, 250), kernel_size=kernel_size, stride=stride)
    counts = F.fold(torch.ones_like(patches), output_size=(250, 250), kernel_size=kernel_size, stride=stride)
    print(f'folded shape = {folded.shape}')
    # folded shape = torch.Size([2, 1, 250, 250])
    print(f'counts shape = {counts.shape}')
    # counts shape = torch.Size([2, 1, 250, 250])

    # Divide the folded tensor by the counts tensor to get the original pixel values.
    result = folded / counts
    # Exact equality holds here because the pixel values are integers; use torch.allclose for general float data.
    assert torch.eq(result, t).all()

    # Display the reconstructed coffee image (first item in the batch).
    Image.fromarray(result[0, 0].numpy().astype(np.uint8))
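The gist works because (250 - 200) is a multiple of 50, so every pixel falls inside at least one patch. When that is not the case, uncovered pixels get a count of 0 and folded / counts yields NaN there. This minimal sketch pads the bottom and right edges until the patch grid covers the image, then crops the result; it also uses F.unfold so the batch can have any number of channels. The helper name reconstruct is hypothetical, and the random 260x250 RGB batch is stand-in data.

```python
import torch
import torch.nn.functional as F

def reconstruct(x: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Hypothetical helper: split x [B, C, H, W] into overlapping patches and average them back.

    Assumes H and W are at least kernel_size.
    """
    B, C, H, W = x.shape
    # Smallest padding so that (size - kernel_size) is a multiple of stride.
    pad_h = (-(H - kernel_size)) % stride
    pad_w = (-(W - kernel_size)) % stride
    xp = F.pad(x, (0, pad_w, 0, pad_h))  # (left, right, top, bottom), zero padding
    Hp, Wp = H + pad_h, W + pad_w

    patches = F.unfold(xp, kernel_size=kernel_size, stride=stride)  # [B, C*k*k, L]
    # ... per-patch filtering would go here ...
    folded = F.fold(patches, output_size=(Hp, Wp), kernel_size=kernel_size, stride=stride)
    counts = F.fold(torch.ones_like(patches), output_size=(Hp, Wp), kernel_size=kernel_size, stride=stride)
    return (folded / counts)[:, :, :H, :W]  # crop the padding away

x = torch.rand(2, 3, 260, 250)  # batch of two RGB images; 260 - 200 is not a multiple of 50
y = reconstruct(x, kernel_size=200, stride=50)
print(y.shape)               # torch.Size([2, 3, 260, 250])
print(torch.allclose(y, x))  # True
```

Zero padding is harmless for a pure round trip because the padded region is cropped. If the patches are actually filtered, padding with F.pad(..., mode='reflect') may give less edge bias.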