How to split tensors with overlap and then reconstruct the original tensor?

F.fold should work for your use case.
Here is a small example creating the expected input shape step by step:

import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 1024, 1024
x = torch.randn(B, C, H, W)

kernel_size = 128
stride = 64
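# unfold height (dim 2) first, then width (dim 3), so each patch is laid out
# as [k_h, k_w]; the reverse order would hand F.fold transposed patches
patches = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
print(patches.shape) # [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]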

# perform the operations on each patch
# ...

# reshape output to match F.fold input
patches = patches.contiguous().view(B, C, -1, kernel_size*kernel_size)
print(patches.shape) # [B, C, nb_patches_all, kernel_size*kernel_size]
patches = patches.permute(0, 1, 3, 2) 
print(patches.shape) # [B, C, kernel_size*kernel_size, nb_patches_all]
patches = patches.contiguous().view(B, C*kernel_size*kernel_size, -1)
print(patches.shape) # [B, C*prod(kernel_size), L] as expected by Fold
# https://pytorch.org/docs/stable/nn.html#torch.nn.Fold

# note: F.fold sums the values of overlapping patches
output = F.fold(
    patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
print(output.shape) # [B, C, H, W]
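
One caveat for the "reconstruct the original tensor" part: F.fold adds up every patch value that lands on the same pixel, so with stride < kernel_size the folded result is a sum, not x itself. Here is a minimal sketch of one way to undo that, assuming the patches are unchanged (or that averaging the overlaps is what you want): fold a tensor of ones to count how many patches cover each pixel, then divide. The smaller 256x256 size and the names summed, counts and reconstructed are my own choices for illustration.

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 256, 256  # smaller than the example above, to keep the check fast
kernel_size, stride = 128, 64

x = torch.randn(B, C, H, W)

# Unfold height (dim 2) first, then width (dim 3): patches are [k_h, k_w].
patches = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)

# [B, C, n_h, n_w, k, k] -> [B, C, L, k*k] -> [B, C*k*k, L]
patches = patches.contiguous().view(B, C, -1, kernel_size * kernel_size)
patches = patches.permute(0, 1, 3, 2).contiguous()
patches = patches.view(B, C * kernel_size * kernel_size, -1)

# F.fold sums all patch values that map to the same output pixel.
summed = F.fold(
    patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)

# Folding ones gives the number of patches covering each pixel.
counts = F.fold(
    torch.ones_like(patches), output_size=(H, W),
    kernel_size=kernel_size, stride=stride)

# Dividing the sum by the coverage count averages the overlaps.
reconstructed = summed / counts
print(torch.allclose(reconstructed, x, atol=1e-6))
```

This relies on every pixel being covered by at least one patch, which holds when (H - kernel_size) and (W - kernel_size) are multiples of stride; otherwise counts contains zeros at the uncovered border and the input would need padding first.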

Let me know if this works for you.