Vectorising/Optimising Grid Image Reassembly

So I have some code that takes a tensor of shape (4, 3, 16, 16) and reassembles it into a tensor of shape (3, 32, 32) by taking the four patches along dimension 0 and placing them in this order:

[[0, 1],
 [2, 3]]

which is essentially like taking a stack of jigsaw pieces and putting them back together. The code currently looks like this:

import math

import torch


def reassemble_jigsaw(patches):
    """
    Reassemble an image given the jigsaw patches.
    :param patches: The jigsaw image patches (not shuffled), shape (grid_length**2, C, H, W).
    :return: Image with all the image patches put together in the order they are given.
    """
    grid_length = math.isqrt(len(patches))  # e.g. 4 patches -> 2x2 grid
    channels, patch_row_length, patch_col_length = patches.shape[1:]
    # new_zeros keeps the dtype and device of the input patches
    reassembled = patches.new_zeros(
        (channels, patch_row_length * grid_length, patch_col_length * grid_length))
    for row in range(grid_length):
        for col in range(grid_length):
            # Patches are taken in row-major order: index = row * grid_length + col
            reassembled[:,
                        patch_row_length * row:patch_row_length * (row + 1),
                        patch_col_length * col:patch_col_length * (col + 1)] = patches[row * grid_length + col]

    return reassembled
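The loop writes patch `row * grid_length + col` into grid cell (row, col), which is plain row-major order. Inverting that, patch i goes to row i // g and column i % g, where g is the grid side length. A tiny sketch of the mapping (the helper `grid_position` is hypothetical, for illustration only); this row-major layout is what makes a loop-free reshape/permute possible:

```python
def grid_position(index: int, grid_length: int) -> tuple[int, int]:
    """Return the (row, col) grid cell that patch `index` occupies in row-major order."""
    return divmod(index, grid_length)


# For a 2x2 grid: patches 0 and 1 fill the top row, 2 and 3 the bottom row.
positions = [grid_position(i, 2) for i in range(4)]
print(positions)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```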

Unfortunately, I know this is slow because it uses Python loops. There are only a few iterations per image (four for a 2x2 grid), but the slowdown is very noticeable when reassembling many images back to back.
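Since the slowdown shows up when reassembling many images, one option is to stack them and reassemble the whole batch in a single call with no Python loop. A minimal sketch using only `reshape` and `permute`, assuming the patches for B images are stacked as (B, g*g, C, h, w) in the row-major order above (the function name `reassemble_batch` is hypothetical):

```python
import math

import torch


def reassemble_batch(patches: torch.Tensor) -> torch.Tensor:
    """(B, g*g, C, h, w) row-major patches -> (B, C, g*h, g*w) images."""
    b, n, c, h, w = patches.shape
    g = math.isqrt(n)
    assert g * g == n, "number of patches must be a perfect square"
    # Split the patch index into (grid row, grid col): (B, g, g, C, h, w)
    x = patches.reshape(b, g, g, c, h, w)
    # Reorder to (B, C, grid row, h, grid col, w)
    x = x.permute(0, 3, 1, 4, 2, 5)
    # Merge (grid row, h) into height and (grid col, w) into width
    return x.reshape(b, c, g * h, g * w)


batch = torch.rand(8, 4, 3, 16, 16)  # 8 images, each split into 4 patches
images = reassemble_batch(batch)
print(images.shape)  # torch.Size([8, 3, 32, 32])
```

For a single image, `reassemble_batch(patches.unsqueeze(0))[0]` gives the same (3, 32, 32) result as the loop. The final `reshape` generally has to copy because the permuted tensor is not contiguous, but that is one copy for the whole batch instead of many small slice assignments per image.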

I looked everywhere for a solution and found the Stack Overflow question "How can I fold a Tensor that I unfolded with PyTorch that has overlap?", but I couldn't figure out how to apply it to my case, and I'm not sure it would be any quicker. Is there a way to do the same thing faster?
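For non-overlapping patches like these, `torch.nn.functional.fold` can do the reassembly once the patches are laid out the way `fold` expects: shape (N, C*h*w, L), one column per patch, with the L patches in row-major order. Setting `stride` equal to `kernel_size` means the blocks do not overlap, so nothing gets summed together. A minimal sketch for a single image (the name `reassemble_with_fold` is hypothetical):

```python
import math

import torch
import torch.nn.functional as F


def reassemble_with_fold(patches: torch.Tensor) -> torch.Tensor:
    """(g*g, C, h, w) row-major patches -> (C, g*h, g*w) image via F.fold."""
    n, c, h, w = patches.shape
    g = math.isqrt(n)
    assert g * g == n, "number of patches must be a perfect square"
    # fold expects (N, C*h*w, L): flatten each patch into one column, L = n
    cols = patches.reshape(n, c * h * w).t().unsqueeze(0)
    # Non-overlapping blocks: stride == kernel_size
    image = F.fold(cols, output_size=(g * h, g * w), kernel_size=(h, w), stride=(h, w))
    return image.squeeze(0)


patches = torch.rand(4, 3, 16, 16)
image = reassemble_with_fold(patches)
print(image.shape)  # torch.Size([3, 32, 32])
```

For a batch, the same idea should work with N = B by reshaping to (B, g*g, C*h*w) and calling `.transpose(1, 2)`; whether this is faster than the plain reshape/permute on your hardware is worth measuring rather than assuming.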

Bump, still waiting for a response.

Using (un)fold is the right approach, and you'll find a few examples posted in this forum by me and others. Unfortunately, you didn't explain where you got stuck with this approach.
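To connect this to the question, the patch layout there is exactly what `F.unfold` produces with a non-overlapping kernel, so an unfold/fold round trip returns the original image. A minimal sketch, assuming a 3x32x32 image and 16x16 patches:

```python
import torch
import torch.nn.functional as F

image = torch.rand(1, 3, 32, 32)  # batch of one 3x32x32 image
h = w = 16                        # patch size; stride == kernel_size -> no overlap

# unfold: (1, C*h*w, L) with L = 4 patches in row-major order
cols = F.unfold(image, kernel_size=(h, w), stride=(h, w))
# View the columns as the (4, 3, 16, 16) patch stack from the question
patches = cols.squeeze(0).t().reshape(-1, 3, h, w)

# fold reverses unfold exactly when the blocks do not overlap
restored = F.fold(cols, output_size=(32, 32), kernel_size=(h, w), stride=(h, w))

print(patches.shape)                 # torch.Size([4, 3, 16, 16])
print(torch.equal(restored, image))  # True
```

With overlapping patches (stride smaller than `kernel_size`), `fold` sums the overlapping values, so recovering the image would need an extra normalisation step; with non-overlapping patches, as here, no such step is needed.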