How to split tensors with overlap and then reconstruct the original tensor?

I encountered a problem. My network is trained with tensors of size BxCx128x128, but I need to verify its image reconstruction performance with images of size 1024x1024. To make the reconstruction smooth, I need to split my input of size BxCx1024x1024 into BxCx128x128 tensors with overlap, which are then fed to the network for reconstruction. The reconstructed BxCx128x128 tensors should then be used to rebuild the BxCx1024x1024 tensor by averaging the overlapping elements. Note that (Size(in)-128)/stride may not be an integer. How can I use a padding strategy to ensure that every element is covered by at least one patch? And how can I recover the BxCx1024x1024 tensor from the overlapping BxCx128x128 tensors? Could anyone give me some suggestions? Thanks in advance for your consideration.


fold should work in your use case.
Here is a small example creating the expected input shape step by step:

import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 1024, 1024
x = torch.randn(B, C, H, W)

kernel_size = 128
stride = 64
patches = x.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride)
print(patches.shape) # [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]

# perform the operations on each patch
# ...

# reshape output to match F.fold input
patches = patches.contiguous().view(B, C, -1, kernel_size*kernel_size)
print(patches.shape) # [B, C, nb_patches_all, kernel_size*kernel_size]
patches = patches.permute(0, 1, 3, 2) 
print(patches.shape) # [B, C, kernel_size*kernel_size, nb_patches_all]
patches = patches.contiguous().view(B, C*kernel_size*kernel_size, -1)
print(patches.shape) # [B, C*prod(kernel_size), L] as expected by Fold
# https://pytorch.org/docs/stable/nn.html#torch.nn.Fold

output = F.fold(
    patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
print(output.shape) # [B, C, H, W]

Let me know if this works for you. :slight_smile:
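As a side note (my own addition, not part of the original answer), the number of windows per spatial dimension follows the usual sliding-window formula, which is an easy way to check whether a given stride covers the input exactly:

def num_windows(size, kernel_size, stride):
    # number of windows tensor.unfold produces along a dimension of length `size`
    return (size - kernel_size) // stride + 1

print(num_windows(1024, 128, 64))   # 15 patches per dimension
print((1024 - 128) % 64 == 0)       # True: the last window ends exactly at the border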


Thanks for your reply. In fact, I had already noticed these two functions in other related posts (e.g., Patch Making Does Pytorch have Anything to Offer?), but there are still several questions.

  1. In the example, the stride is set to 64, which makes (Size(in)-128)/64 an integer. But how do I handle arbitrary strides that do not satisfy this (e.g. stride=40) with padding or other strategies?
  2. When we get patches of size [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size], we can of course perform the operation on each patch by looping over the patch dimensions. Is there a way to perform the operation on all patches at once?
  3. Is the spatial neighborhood information preserved in the unfold and fold process? The reconstructed large tensor should average the overlapping elements of the patches, but I did not see any sign of this averaging.

Many thanks for your help.

  1. Padding should work fine.

  2. Depending on the operation, you could use e.g. a single batched multiplication, as seen here (see also the sketch after this list).

  3. Yes, but the overlapping elements are summed, so you might need to normalize them afterwards, e.g. with a mask.
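Regarding point 2, a minimal sketch of processing all patches in one forward pass, assuming `model` is the network trained on BxCx128x128 inputs (the reshape through the batch dimension is my own illustration, not the code from the linked post):

# patches: [B, C, nb_h, nb_w, k, k] as returned by the two unfold calls
B, C, nb_h, nb_w, k, _ = patches.shape
flat = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, C, k, k)   # [B*nb_h*nb_w, C, k, k]
out = model(flat)                                               # single forward pass over all patches
out = out.reshape(B, nb_h, nb_w, C, k, k).permute(0, 3, 1, 2, 4, 5)
# out: [B, C, nb_h, nb_w, k, k], ready for the reshapes feeding F.fold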

Thanks for your reply.

  1. How can padding be applied automatically in unfold? Or should I just pad before the unfolding operation?
  2. In your posted example, why do we need to permute the patches with (0, 1, 3, 2)? What is the difference between
patches = patches.contiguous().view(B, C, -1, kernel_size*kernel_size)
print(patches.shape) # [B, C, nb_patches_all, kernel_size*kernel_size]
patches = patches.permute(0, 1, 3, 2) 
print(patches.shape) # [B, C, kernel_size*kernel_size, nb_patches_all]
patches = patches.contiguous().view(B, C*kernel_size*kernel_size, -1)
print(patches.shape) # [B, C*prod(kernel_size), L] as expected by Fold

and

patches2 = patches.contiguous().view(B, C, -1, kernel_size*kernel_size)
print(patches2.shape)
patches2 = patches2.view(B, C*kernel_size*kernel_size, -1)
print(patches2.shape)

Thanks a lot!

  1. Pad the data before unfolding (a sketch of one way to compute the padding follows below).

  2. If you collapse non-neighboring dimensions with view, you are interleaving the pixels. You could try it on an image tensor; the result should look visibly interleaved.
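Regarding point 1, a minimal sketch of padding before unfold so that an arbitrary stride covers every pixel (the helper below is my own illustration, not an official API; x is assumed to be a [B, C, H, W] tensor):

import torch.nn.functional as F

def pad_for_unfold(x, kernel_size, stride):
    # pad on the right/bottom so that (size - kernel_size) becomes divisible by stride
    B, C, H, W = x.shape
    pad_h = (stride - (H - kernel_size) % stride) % stride
    pad_w = (stride - (W - kernel_size) % stride) % stride
    # F.pad takes (left, right, top, bottom) for the last two dims
    return F.pad(x, (0, pad_w, 0, pad_h)), pad_h, pad_w

x_padded, pad_h, pad_w = pad_for_unfold(x, kernel_size=128, stride=40)
# after unfold/fold on x_padded, crop the result back to the original size: output[..., :H, :W]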

Thanks for your reply.
I am a PyTorch beginner. Could you please provide a function, including padding, that can handle arbitrary strides and recover the original tensor?

This post is using padding so that the original input tensor can be restored.
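For reference, a self-contained sketch along those lines (my own illustration, not the code from the linked post): pad, unfold into overlapping patches, fold back, and average the overlaps by dividing by a fold of ones.

import torch
import torch.nn.functional as F

def split_with_overlap(x, kernel_size, stride):
    # x: [B, C, H, W] -> patches: [B*L, C, kernel_size, kernel_size]
    B, C, H, W = x.shape
    pad_h = (stride - (H - kernel_size) % stride) % stride
    pad_w = (stride - (W - kernel_size) % stride) % stride
    x = F.pad(x, (0, pad_w, 0, pad_h))                           # pad right/bottom only
    cols = F.unfold(x, kernel_size=kernel_size, stride=stride)   # [B, C*k*k, L]
    L = cols.size(-1)
    patches = cols.transpose(1, 2).reshape(B * L, C, kernel_size, kernel_size)
    return patches, (H, W, pad_h, pad_w, L)

def merge_with_averaging(patches, meta, kernel_size, stride):
    H, W, pad_h, pad_w, L = meta
    B = patches.size(0) // L
    C = patches.size(1)
    cols = patches.reshape(B, L, C * kernel_size * kernel_size).transpose(1, 2)
    out_size = (H + pad_h, W + pad_w)
    summed = F.fold(cols, output_size=out_size, kernel_size=kernel_size, stride=stride)
    # count how often each pixel was covered and divide to average the overlaps
    ones = torch.ones_like(summed)
    counts = F.fold(F.unfold(ones, kernel_size=kernel_size, stride=stride),
                    output_size=out_size, kernel_size=kernel_size, stride=stride)
    return (summed / counts)[..., :H, :W]

x = torch.randn(2, 3, 1024, 1024)
patches, meta = split_with_overlap(x, kernel_size=128, stride=40)
# run the network on `patches` here (e.g. in chunks to save memory)
recon = merge_with_averaging(patches, meta, kernel_size=128, stride=40)
print(torch.allclose(recon, x, atol=1e-5))  # True when the patches are left unchanged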

Many thanks for your reply. You are so warm-hearted!


Hi @ptrblck,

My aim is the same as the OP: unfold a large image into overlapping tiles, then fold them back together, averaging values where there was overlap. I’ve been testing this example with dummy image data, but it doesn’t seem to be returning the overlapped sum.

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from skimage.data import astronaut

def tensor2im(input_image, imtype=np.uint8):
    """"Converts a Tensor array into a numpy image array.
    Parameters:
        input_image (tensor) --  the input image tensor array
        imtype (type)        --  the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (image_numpy + 1) / 2.0 * 255.0  # post-processing: scale from [-1, 1] to [0, 255]

    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


tpi = transforms.ToPILImage()
tform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
        ])

x = tform(astronaut()).unsqueeze(0)
B, C, H, W = x.shape

kernel_size = 64
stride = 32
patches = x.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride)
print(patches.shape) # [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]

# perform the operations on each patch
# ...

# reshape output to match F.fold input
patches = patches.contiguous().view(B, C, -1, kernel_size*kernel_size)
print(patches.shape) # [B, C, nb_patches_all, kernel_size*kernel_size]
patches = patches.permute(0, 1, 3, 2) 
print(patches.shape) # [B, C, kernel_size*kernel_size, nb_patches_all]
patches = patches.contiguous().view(B, C*kernel_size*kernel_size, -1)
print(patches.shape) # [B, C*prod(kernel_size), L] as expected by Fold
# https://pytorch.org/docs/stable/nn.html#torch.nn.Fold

output = F.fold(
    patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
print(output.shape) # [B, C, H, W]
# Take a look at the input
tpi(tensor2im(x).transpose(1,2,0))

# Take a look at the output
tpi(tensor2im(output).transpose(1,2,0))

The overflow artifacts are expected here and are easily corrected by division with a mask generated by running a tensor of ones through the unfold/fold operation:


…but what we can see is that the patches have not been reassembled as expected. We do get the same shape, but there is substantial scrambling.

Any help on this problem would be much appreciated. Cheers!


Just figured it out. Had to permute the patches:

patches = x.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride)
becomes
patches = x.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride).permute(0,1,2,3,5,4)
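A quick check of the fix (my own addition, reusing the variables from the snippet above after re-running it with the permuted patches): fold a tensor of ones the same way to get the per-pixel overlap counts, divide, and compare with the input.

mask = F.fold(torch.ones_like(patches), output_size=(H, W),
              kernel_size=kernel_size, stride=stride)
print(torch.allclose(output / mask, x))  # True once the last two patch dims are swapped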


Hi eburling,

I am trying to do the same thing (unfold a large image into overlapping tiles, then fold them back together, averaging values where there was overlap). I see you figured out how to unfold and fold the image back together. How did you take care of averaging the overlapping values?

Hi @whuang7000,

I generated a normalization mask by unfolding/refolding a torch.ones tensor of equal size, then divided my output image by the normalization mask. Probably a better way out there, but it worked for me.


I also want to do the same, but I have a target image as well, which means I need to split both the input and the target so that each input patch has a corresponding target patch. How can I do that? Also, is there a way to split the input into non-square patches? My original images are large, and when I feed them to the super-resolution model I run out of memory, so I want to divide the input in half, reconstruct each sub-image, and then stitch the sub-images together to reconstruct the whole image.
Thanks

You could use tensor.unfold on both the input and target tensors (or alternatively the nn.Unfold module).
In both approaches you can specify the kernel size and use non-square shapes for it.

Thanks for your answer. But my input and target are not the same size; the target image is larger than the input by a specific scale factor.

This would mean that you would either end up with more patches in the target or with bigger patches (larger kernel size in the unfold operation). Based on your initial description it seems you would like to create corresponding target patches for each input patch, so scaling up the kernel size (and stride) should probably work.
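A minimal sketch of that pairing (my own illustration), assuming lr and hr are tensors of shape [B, C, H, W] and [B, C, scale*H, scale*W] and using made-up, non-square patch sizes:

scale = 2
lr_kernel, lr_stride = (96, 128), (48, 64)           # non-square patches are fine
hr_kernel = (lr_kernel[0] * scale, lr_kernel[1] * scale)
hr_stride = (lr_stride[0] * scale, lr_stride[1] * scale)

lr_patches = lr.unfold(2, lr_kernel[0], lr_stride[0]).unfold(3, lr_kernel[1], lr_stride[1])
hr_patches = hr.unfold(2, hr_kernel[0], hr_stride[0]).unfold(3, hr_kernel[1], hr_stride[1])
# lr_patches: [B, C, nb_h, nb_w, lr_kernel[0], lr_kernel[1]]
# hr_patches: [B, C, nb_h, nb_w, hr_kernel[0], hr_kernel[1]]
# the patch grids line up, so lr_patches[:, :, i, j] corresponds to hr_patches[:, :, i, j]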

Thanks for your help. I did that, but I still have a problem. I have an input image, let's call it LR, and a target image, HR. I feed the LR image into my model (a super-resolution reconstruction model), and the output of the model is the SR image. SR and HR have the same size and are larger than LR by a scale factor m. Since my LR images are large, feeding a whole LR image into the model runs out of memory, which is why I want to divide it into smaller patches, reconstruct the corresponding SR patch for each one, and then combine them to create the final SR image; the result should be the same as when the whole LR image is fed into the model directly. But after applying fold on the SR patches, the SR image is messy. Do you have any suggestions on how to solve this issue?
Thanks.
Thanks.

# LR_Image.shape: torch.Size([3, 702, 1020])
# HR_Image.shape: torch.Size([3, 1404, 2040])
scale = 2
lr_kernel_size = (351, 510)
hr_kernel_size = (702, 1020)
lr_stride = lr_kernel_size
hr_stride = hr_kernel_size * scale

fold_params = dict(kernel_size=lr_kernel_size, stride=lr_stride)
unfoldlr = nn.Unfold(**fold_params)
lr_patches = unfoldlr(LR_Image)
# lr_patches.shape: torch.Size([1, 537030, 4])

unfoldhr = nn.Unfold(**dict(kernel_size=hr_kernel_size, stride=hr_stride))
hr_patches = unfoldhr(HR_Image)
# hr_patches.shape: torch.Size([1, 2148120, 4])

sr_patches = self.model(lr_patches, idx_scale)
# sr_patches.shape: torch.Size([1, 2148120, 4])

fold = nn.Fold((1404, 2040), **dict(kernel_size=hr_kernel_size, stride=hr_stride))
sr = fold(sr_patches)
# sr.shape: torch.Size([1, 3, 1404, 2040])

That wouldn’t be guaranteed, as your model only sees patches and could thus create artifacts at the borders of these patches.
You could try to use some overlap (and sum the overlapping pixels in the output, or use another reduction), but I’m unsure which approach would work best.
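For illustration only, the overlap-and-average idea at the SR scale could look like this (my own sketch with made-up sizes; sr_patches is assumed to already have shape [B, C*702*1020, L] for the overlapping grid):

hr_kernel_size = (702, 1020)
hr_stride = (351, 510)                              # half-kernel stride -> overlapping tiles
out_size = (1404, 2040)

fold = nn.Fold(out_size, kernel_size=hr_kernel_size, stride=hr_stride)
unfold = nn.Unfold(kernel_size=hr_kernel_size, stride=hr_stride)

sr_sum = fold(sr_patches)                           # overlapping pixels are summed
counts = fold(unfold(torch.ones_like(sr_sum)))      # how often each pixel was covered
sr = sr_sum / counts                                # average instead of sum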

Hi @ptrblck ,

I would like to take the maximum value instead of averaging in the overlapping areas. Do you have a simple and effective idea for doing that?

Thanks in advance for your help!
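One possible approach (my own sketch, not from this thread): since fold can only sum, skip it for the reduction and write each patch into an output canvas with torch.maximum, assuming patches of shape [B, C, nb_h, nb_w, kernel_size, kernel_size] from tensor.unfold:

out = torch.full((B, C, H, W), float('-inf'))
for i in range(patches.size(2)):          # loop over patch rows
    for j in range(patches.size(3)):      # loop over patch columns
        h0, w0 = i * stride, j * stride
        out[:, :, h0:h0 + kernel_size, w0:w0 + kernel_size] = torch.maximum(
            out[:, :, h0:h0 + kernel_size, w0:w0 + kernel_size],
            patches[:, :, i, j])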