How to split tensors with overlap and then reconstruct the original tensor?

I don’t think Fold has a built-in “reduction” operation that you could set to e.g. max, so you might need to unfold the patches and reconstruct the output manually, applying max on the overlaps.

Thanks! Do you have a code example at hand to do that?
Regards,

I’m currently studying image restoration (denoising) and facing the same problem as mentioned above.
In my case, the model can only denoise inputs at a fixed resolution of [1, 3, 256, 256] (B, C, H, W).
I want to denoise arbitrary resolutions such as 512x768 (Kodak24) and 321x481 (CBSD68).
The following is my code; it summarizes the advice from the experts above:

import math
import torch


def overlapped_square(timg, kernel=256, stride=128):
    patch_images = []
    b, c, h, w = timg.size()
    # round the larger side up to the next multiple of the kernel, e.g. 321, 481 -> 512
    X = int(math.ceil(max(h, w) / float(kernel)) * kernel)
    img = torch.zeros(1, 3, X, X).type_as(timg)   # square canvas for the padded image
    mask = torch.zeros(1, 1, X, X).type_as(timg)  # marks the valid (non-padded) region

    # center the input on the canvas and mark its footprint in the mask
    img[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)] = timg
    mask[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)].fill_(1.0)

    # extract overlapping patches; the final permute also swaps the two
    # within-patch dims back to (h, w) order, since unfolding W before H
    # produces transposed patch contents
    patch = img.unfold(3, kernel, stride).unfold(2, kernel, stride)
    patch = patch.contiguous().view(b, c, -1, kernel, kernel)  # [B, C, #patches, K, K]
    patch = patch.permute(2, 0, 1, 4, 3)  # [#patches, B, C, K, K]

    for each in range(len(patch)):
        patch_images.append(patch[each])

    return patch_images, mask, X


import os
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from skimage import img_as_ubyte

for file_ in files:
    img = Image.open(file_).convert('RGB')
    input_ = TF.to_tensor(img).unsqueeze(0).cuda()

    with torch.no_grad():
        # pad to a multiple of 256 and split into overlapping patches
        square_input_, mask, max_wh = overlapped_square(input_, kernel=model_img, stride=stride)

        # denoise each patch and stack the results along the batch dimension
        restored_patches = []
        for data in square_input_:
            restored_patches.append(model(data))
        output_patch = torch.cat(restored_patches, dim=0)

        B, C, PH, PW = output_patch.shape  # B is the number of patches here
        weight = torch.ones(B, C, PH, PW).type_as(output_patch)  # weight mask for averaging overlaps

        # reshape the restored patches to the [1, C*K*K, L] layout expected by F.fold
        patch = output_patch.contiguous().view(B, C, -1, model_img * model_img)
        patch = patch.permute(2, 1, 3, 0)  # [1, C, K*K, #patches]
        patch = patch.contiguous().view(1, C * model_img * model_img, -1)

        weight_mask = weight.contiguous().view(B, C, -1, model_img * model_img)
        weight_mask = weight_mask.permute(2, 1, 3, 0)  # [1, C, K*K, #patches]
        weight_mask = weight_mask.contiguous().view(1, C * model_img * model_img, -1)

        # fold sums overlapping regions, so divide by the folded weight mask to average them
        restored = F.fold(patch, output_size=(max_wh, max_wh), kernel_size=model_img, stride=stride)
        we_mk = F.fold(weight_mask, output_size=(max_wh, max_wh), kernel_size=model_img, stride=stride)
        restored /= we_mk

        restored = torch.masked_select(restored, mask.bool()).reshape(input_.shape)
        restored = torch.clamp(restored, 0, 1)

    restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
    restored = img_as_ubyte(restored[0])

    f = os.path.splitext(os.path.split(file_)[-1])[0]
    save_img((os.path.join(out_dir, f + '.png')), restored)

The function overlapped_square() above pads the input to a multiple of 256 before splitting it into overlapping patches.
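
As a quick sanity check of the shapes it produces (a sketch, assuming a 321x481 CBSD68 image and the default kernel/stride):

import torch

timg = torch.randn(1, 3, 321, 481)
patches, mask, X = overlapped_square(timg, kernel=256, stride=128)
print(X)                 # 512: max(321, 481) rounded up to a multiple of 256
print(len(patches))      # 9: a 3x3 grid of overlapping 256x256 patches
print(patches[0].shape)  # torch.Size([1, 3, 256, 256])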

Thanks everybody for helping!

Feel free to discuss it below if you have more ideas.

Same here. @ptrblck consider adding this fix to your (almost perfect!) recipe.

If the width is different for every image, how can I handle that?

If each image has a different shape, you could unfold each image separately.
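
A minimal sketch of that idea, assuming a list of 3-channel images of varying sizes (the sizes below are hypothetical) and a kernel/stride that tile each image exactly:

import torch

kernel_size, stride = 64, 32
images = [torch.randn(1, 3, 128, 256), torch.randn(1, 3, 96, 160)]

all_patches = []
for x in images:
    # unfold H first, then W, so patch contents stay in (h, w) order
    p = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    all_patches.append(p.contiguous().view(-1, 3, kernel_size, kernel_size))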

I have a problem when I reassemble the patches. This is my original image:

[image: original]

The shape of patches is:

torch.Size([1024, 3, 32, 32])

With your code, I obtain this:

[image: reconstructed]

The output.shape is correct (torch.Size([1, 3, 256, 4096])). What can I do?

I’ve found the solution, adapting @ptrblck’s snippet:

import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 1024, 1024
x = torch.randn(B, C, H, W)

kernel_size = 128
stride = 64
# unfold H (dim 2) before W (dim 3) so the patch contents stay in (h, w) order
patches = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
print(patches.shape) # [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]

# perform the operations on each patch
# ...

# reshape output to match F.fold input
patches = patches.contiguous().view(B, C, -1, kernel_size*kernel_size)
print(patches.shape) # [B, C, nb_patches_all, kernel_size*kernel_size]
patches = patches.permute(0, 1, 3, 2)
print(patches.shape) # [B, C, kernel_size*kernel_size, nb_patches_all]
patches = patches.contiguous().view(B, C*kernel_size*kernel_size, -1)
print(patches.shape) # [B, C*prod(kernel_size), L] as expected by Fold
# https://pytorch.org/docs/stable/nn.html#torch.nn.Fold

output = F.fold(
    patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
print(output.shape) # [B, C, H, W]
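
Note that with stride < kernel_size, F.fold sums the overlapping regions, so output will not exactly equal x here. A minimal sketch of the usual normalization (fold a ones tensor the same way and divide):

ones = torch.ones_like(patches)
divisor = F.fold(ones, output_size=(H, W), kernel_size=kernel_size, stride=stride)
output = output / divisor  # each pixel becomes the average of its overlapping patches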

Hi @ptrblck
I’ve built a naive function to reconstruct my images from patches, using a max operation on the overlaps:

import torch


def max_fold(maps: torch.Tensor, output_size: tuple,
             kernel_size: tuple, stride: tuple
             ) -> torch.Tensor:
    # maps: [#patches, C, k_h, k_w], ordered row-major over the patch grid
    output = torch.zeros((1, maps.shape[1], *output_size),
                         dtype=maps.dtype, device=maps.device)

    # start/end coordinates of every patch along each spatial dimension
    fn = lambda x: [[i, i + kernel_size[x]] for i in range(0, output_size[x], stride[x])][:-1]
    locs = [[*h, *w] for h in fn(0) for w in fn(1)]

    # place each patch on an empty canvas and keep the element-wise max
    for loc, m in zip(locs, maps):
        patch = torch.zeros(output.shape, dtype=maps.dtype, device=maps.device)
        patch[:, :, loc[0]:loc[1], loc[2]:loc[3]] = m
        output = torch.max(output, patch)

    return output

It works but is 4 times slower than built-in methods…
Do you have any ideas to improve this function?
Cheers!

You might want to use unfold as described in previous posts instead of indexing the tensor in a loop.

Thanks for the reply!
Okay, but how can I accumulate max values without a loop?
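
One option (a sketch of an alternative, not a built-in fold mode; it needs scatter_reduce_, available in PyTorch >= 1.12): build the linear index of every output pixel with the same unfold, then scatter the patch values with reduce="amax". It assumes the patches tile the output exactly and that every pixel is covered by at least one patch:

import torch

def max_fold_vectorized(patches, output_size, kernel_size, stride):
    # patches: [L, C, k, k], in the order produced by
    # x.unfold(2, k, s).unfold(3, k, s) flattened over the patch grid
    L, C, k, _ = patches.shape
    H, W = output_size
    # linear index of every output pixel, unfolded exactly like the input
    idx = torch.arange(H * W).view(1, 1, H, W)
    idx = idx.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    idx = idx.reshape(-1, kernel_size, kernel_size)   # [L, k, k]
    idx = idx.unsqueeze(1).expand(-1, C, -1, -1)      # [L, C, k, k]
    out = torch.full((C, H * W), float('-inf'), dtype=patches.dtype)
    # reduce="amax" keeps the maximum of all values scattered onto the same
    # pixel; pixels not covered by any patch stay at -inf
    out.scatter_reduce_(1, idx.permute(1, 0, 2, 3).reshape(C, -1),
                        patches.permute(1, 0, 2, 3).reshape(C, -1),
                        reduce="amax", include_self=False)
    return out.view(1, C, H, W)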

Changing patches = x.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride) to patches = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride) also helps: unfolding the height dimension first keeps the patch contents in (h, w) order, as the small check below shows.
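
import torch

x = torch.arange(16.).view(1, 1, 4, 4)
a = x.unfold(3, 2, 2).unfold(2, 2, 2)  # W first: patch contents come out transposed
b = x.unfold(2, 2, 2).unfold(3, 2, 2)  # H first: patch contents keep (h, w) order
print(a[0, 0, 0, 0])  # tensor([[0., 4.], [1., 5.]])
print(b[0, 0, 0, 0])  # tensor([[0., 1.], [4., 5.]])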

Hi, we support patch extraction and recombination in kornia’s extract_tensor_patches.

There’s also a lot of well-supported tooling, plus examples of extraction at different scales and with geometric transformations, since patch handling is the basis of much local-feature detection in computer vision.
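
A rough sketch of how that could look (treat the exact signatures as an assumption and check the kornia docs; combine_tensor_patches in kornia.contrib is the recombination counterpart):

import torch
from kornia.contrib import extract_tensor_patches, combine_tensor_patches

x = torch.randn(1, 3, 256, 256)
# stride == window_size sidesteps overlap handling; see the docs for overlapping patches
patches = extract_tensor_patches(x, window_size=64, stride=64)  # [B, N, C, 64, 64]
y = combine_tensor_patches(patches, original_size=(256, 256), window_size=64, stride=64)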