Fold image patches back into a single image

I have a function that splits an image into jigsaw-puzzle patches using `unfold`:

def make_jigsaw_puzzle(x, grid_size=(2, 2)):
    """Split a single image tensor into a grid of non-overlapping patches.

    Args:
        x: image tensor of shape ``C x H x W``.
        grid_size: ``(rows, cols)`` of the patch grid. ``H`` must be divisible
            by ``rows`` and ``W`` by ``cols``.

    Returns:
        Tensor of shape ``rows * cols x C x (H // rows) x (W // cols)``.
        Patches are ordered row-major: left to right, then top to bottom.

    Raises:
        AssertionError: if ``H`` or ``W`` is not divisible by the grid size.
    """
    C, H, W = x.size()
    rows, cols = grid_size

    assert H % rows == 0, f"H={H} is not divisible by grid_size[0]={rows}"
    assert W % cols == 0, f"W={W} is not divisible by grid_size[1]={cols}"

    patch_h = H // rows
    patch_w = W // cols

    # unfold(dim, size, step) with step == size gives non-overlapping windows.
    # Each axis must use its own patch extent; mixing H and W only works for
    # square patches.
    # Result layout: C x rows x cols x patch_h x patch_w
    x_jigsaw = x.unfold(1, patch_h, patch_h).unfold(2, patch_w, patch_w)

    # Move the channel axis behind the grid axes before flattening. A plain
    # view() on the C-first layout would interleave channels across patches.
    x_jigsaw = x_jigsaw.permute(1, 2, 0, 3, 4)
    return x_jigsaw.reshape(-1, C, patch_h, patch_w)

The shape of `x_jigsaw` is `(grid_size[0] * grid_size[1]) x C x (H // grid_size[0]) x (W // grid_size[1])`.
How can I reconstruct the original image from the patches?

This post might help.

Thank you! It works:

def jigsaw_to_image(x, grid_size=(2, 2)):
    """Reassemble a row-major grid of patches into full images.

    Args:
        x: patches of shape ``batch_size x num_patches x c x jigsaw_h x jigsaw_w``,
            or an unbatched ``num_patches x c x jigsaw_h x jigsaw_w`` tensor.
            Patches are expected in row-major order (left to right, then top
            to bottom).
        grid_size: ``(rows, cols)`` of the patch grid;
            ``num_patches`` must equal ``rows * cols``.

    Returns:
        ``batch_size x c x (rows * jigsaw_h) x (cols * jigsaw_w)`` for batched
        input, or ``c x (rows * jigsaw_h) x (cols * jigsaw_w)`` for unbatched
        input.

    Raises:
        AssertionError: if ``num_patches`` does not match the grid size.
    """
    # Accept the unbatched patch layout too by adding a temporary batch axis.
    unbatched = x.dim() == 4
    if unbatched:
        x = x.unsqueeze(0)

    batch_size, num_patches, c, jigsaw_h, jigsaw_w = x.size()
    rows, cols = grid_size
    assert num_patches == rows * cols, (
        f"num_patches={num_patches} does not match grid_size={tuple(grid_size)}"
    )

    x_image = x.view(batch_size, rows, cols, c, jigsaw_h, jigsaw_w)

    # Reorder to batch x c x rows x jigsaw_h x cols x jigsaw_w so that grid
    # rows/cols sit next to the pixel axes they extend; merging the pairs
    # yields the full height and width.
    x_image = x_image.permute(0, 3, 1, 4, 2, 5).contiguous()
    x_image = x_image.view(batch_size, c, rows * jigsaw_h, cols * jigsaw_w)

    return x_image[0] if unbatched else x_image

2 Likes