So I have some code that takes a tensor of shape (4, 3, 16, 16) and reassembles it into a tensor of shape (3, 32, 32). It takes the patches along dimension 0 and places them in this order:

[[0, 1], [2, 3]]

This is essentially like taking a stack of jigsaw pieces and putting them back together. The code currently looks like this:
```python
import numpy as np
import torch


def reassemble_jigsaw(patches):
    """
    Reassemble an image given the jigsaw patches.
    :param patches: The jigsaw image patches (not shuffled), shape (N, C, H, W)
    :return: Image with all the image patches put together in the order they are given.
    """
    grid_length = int(np.sqrt(len(patches)))
    patch_row_length = patches.shape[2]  # patch height
    patch_col_length = patches.shape[3]  # patch width
    reassembled = torch.zeros((3, patch_row_length * grid_length, patch_col_length * grid_length),
                              dtype=patches.dtype, device=patches.device)
    # Copy patch number (row * grid_length + col) into its grid cell, row-major
    for row in range(grid_length):
        for col in range(grid_length):
            reassembled[:, patch_row_length * row:patch_row_length * (row + 1),
                        patch_col_length * col:patch_col_length * (col + 1)] = patches[row * grid_length + col]
    return reassembled
```
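Because the patches are stored in row-major order (patch `k` goes to grid row `k // g`, column `k % g`), one loop-free option should be a single reshape/permute/reshape. The sketch below assumes the input is `(g*g, C, H, W)` with a square grid. `reassemble_jigsaw_fast` is a hypothetical name, not an existing PyTorch function:

```python
import math
import torch


def reassemble_jigsaw_fast(patches: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-free version: (g*g, C, H, W) -> (C, g*H, g*W)."""
    n, c, h, w = patches.shape
    g = math.isqrt(n)
    assert g * g == n, "number of patches must be a perfect square"
    # Split the patch index into (grid_row, grid_col): (g, g, C, H, W)
    x = patches.reshape(g, g, c, h, w)
    # Put each grid axis next to the pixel axis it tiles: (C, g_row, H, g_col, W)
    x = x.permute(2, 0, 3, 1, 4)
    # Merge (g_row, H) into rows and (g_col, W) into columns
    return x.reshape(c, g * h, g * w)
```

The final `reshape` has to copy once because the permuted tensor is not contiguous. The loop version also writes every element once, but it does that with a Python-level loop.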
Unfortunately, I know for a fact that this is slow because it uses Python loops. There aren't many iterations per image, but the slowdown is very noticeable when reassembling many images back-to-back.
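When many images are processed back-to-back, they could instead be stacked into one batch and reassembled in a single call. This sketch assumes a hypothetical layout of `(B, g*g, C, H, W)`, where every image has the same number and size of patches. It splits the patch index into grid row and column, then moves each grid axis next to the pixel axis it tiles:

```python
import math
import torch


def reassemble_jigsaw_batch(patches: torch.Tensor) -> torch.Tensor:
    """Hypothetical batched version: (B, g*g, C, H, W) -> (B, C, g*H, g*W)."""
    b, n, c, h, w = patches.shape
    g = math.isqrt(n)
    assert g * g == n, "number of patches must be a perfect square"
    # Split the patch index into (grid_row, grid_col): (B, g, g, C, H, W)
    x = patches.reshape(b, g, g, c, h, w)
    # Reorder to (B, C, g_row, H, g_col, W)
    x = x.permute(0, 3, 1, 4, 2, 5)
    # Merge grid and pixel axes into full-size rows and columns
    return x.reshape(b, c, g * h, g * w)
```

This replaces B separate calls with one kernel launch sequence, which is usually where the per-image overhead comes from. I haven't benchmarked it against the original.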
I looked everywhere for a solution and found the Stack Overflow question "How can I fold a Tensor that I unfolded with PyTorch that has overlap?", but I couldn't figure out how to apply it to my case, and I'm not sure it would be any faster. Is there a way to perform the same task faster?
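For reference, the `fold` approach from that question also applies here. With `stride` equal to `kernel_size` the blocks don't overlap, so `torch.nn.functional.fold` (which sums overlapping values) simply places each patch. `fold` expects input of shape `(batch, C * kh * kw, L)`, with one column per block and blocks in row-major order. This sketch assumes `(g*g, C, H, W)` input, and `reassemble_jigsaw_fold` is a hypothetical name:

```python
import math
import torch
import torch.nn.functional as F


def reassemble_jigsaw_fold(patches: torch.Tensor) -> torch.Tensor:
    """Hypothetical fold-based version: (g*g, C, H, W) -> (C, g*H, g*W)."""
    n, c, h, w = patches.shape
    g = math.isqrt(n)
    assert g * g == n, "number of patches must be a perfect square"
    # Flatten each patch to a column of length C*H*W: (1, C*H*W, L) with L = n
    cols = patches.reshape(n, c * h * w).t().unsqueeze(0)
    # Non-overlapping blocks: stride == kernel_size, so fold just places them
    out = F.fold(cols, output_size=(g * h, g * w), kernel_size=(h, w), stride=(h, w))
    return out.squeeze(0)
```

I wouldn't necessarily expect this to beat a plain reshape/permute, since `fold` is built to handle overlap, but it gives the same result for non-overlapping patches.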