I have some code that takes a tensor of shape (4, 3, 16, 16)
and reassembles it into a tensor of shape (3, 32, 32)
by taking the four patches along dimension 0 and placing them in a 2 x 2 grid in this order:
[[0, 1],
[2, 3]]
This is essentially like taking a stack of jigsaw pieces and putting them back together. The code currently looks like this:
import numpy as np
import torch


def reassemble_jigsaw(patches):
    """
    Reassemble an image given the jigsaw patches.
    :param patches: The jigsaw image patches (not shuffled)
    :return: Image with all the image patches put together in the order they are given.
    """
    # The patches are assumed to form a square grid, e.g. 4 patches -> 2 x 2
    grid_length = int(np.sqrt(len(patches)))
    patch_row_length = patches[0].shape[1]
    patch_col_length = patches[0].shape[2]
    reassembled = torch.zeros((3, patch_row_length * grid_length, patch_col_length * grid_length))
    for row in range(grid_length):
        for col in range(grid_length):
            # Copy patch number (row * grid_length + col) into its grid cell
            reassembled[:,
                        patch_row_length * row:patch_row_length * (row + 1),
                        patch_col_length * col:patch_col_length * (col + 1)] = patches[row * grid_length + col]
    return reassembled
Unfortunately, this is slow, and I believe the Python loops are the cause. There are only a few iterations per image, but the slowdown is very noticeable when reassembling many images back-to-back.
I searched for a solution and found the Stack Overflow question "How can I fold a Tensor that I unfolded with PyTorch that has overlap?", but I couldn’t figure out how to apply it to my case, and I am not sure whether it would be any quicker. Is there a way to perform the same task faster?
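For reference, the rearrangement described above can also be written without Python loops as a reshape followed by a permute. The following is only a sketch, not a benchmarked solution: it assumes the patches are a single tensor of shape (N, C, H, W) with N a perfect square and the patches stored in row-major grid order, and the name `reassemble_jigsaw_vectorized` is hypothetical.

```python
import torch


def reassemble_jigsaw_vectorized(patches: torch.Tensor) -> torch.Tensor:
    """Loop-free sketch: (g*g, C, H, W) patches -> (C, g*H, g*W) image."""
    n, c, h, w = patches.shape
    g = int(round(n ** 0.5))
    # Assumption: the number of patches is a perfect square (square grid)
    assert g * g == n, "number of patches must be a perfect square"

    # Split the patch index into (grid_row, grid_col): (g, g, C, H, W)
    grid = patches.reshape(g, g, c, h, w)
    # Reorder to (C, grid_row, H, grid_col, W) so that each grid index
    # sits next to the spatial axis it extends
    grid = grid.permute(2, 0, 3, 1, 4)
    # Merge (grid_row, H) into rows and (grid_col, W) into columns
    return grid.reshape(c, g * h, g * w)
```

With this layout, output pixel (row * H + i, col * W + j) comes from patch number row * g + col at position (i, j), which matches the indexing used in the loop version. Whether it is actually faster for a given workload would need to be measured.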