How to fold the unfolded?

EDITED: added channels (making things even harder :worried:)

Hey there,
I have a tensor of shape [batch, C, H, W] which is a batch of images with 3 channels

batch = torch.rand((2,3,64,64))

I’m manipulating this batch, and my goal is to return to this shape after the manipulation was done.
The manipulation is to slice each image into 16x16 squares and represent each square as a flattened vector, and lastly, since each image was sliced, take all the slices and represent them as a sequence.

The manipulation is as follows:

patched_squares = batch.unfold(2,16,16).unfold(3,16,16)

which should slice each image-channel into 16x16 squares, so the output shape will be:

torch.Size([2, 3, 4, 4, 16, 16])

meaning, a batch of size 2, 3 channels, with 4 rows and 4 columns of 16x16 squares
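(As a quick sanity check on what `unfold` does here — this is a toy example of my own, not from the pipeline above — each window really is a contiguous patch of the source:)

```python
import torch

# toy tensor: one image, one channel, 4x4, values 0..15
x = torch.arange(16.).reshape(1, 1, 4, 4)

# slice into 2x2 patches along H and W
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
print(patches.shape)        # torch.Size([1, 1, 2, 2, 2, 2])
print(patches[0, 0, 0, 1])  # top-right patch: tensor([[2., 3.], [6., 7.]])
```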

Next, I’m flattening each 16x16 square into a vector, like so:

patched_squares = patched_squares.flatten(start_dim=4, end_dim=5)

now the size is:

torch.Size([2, 3, 4, 4, 256])

meaning, a batch of size 2, 3 channels, with 4 rows and 4 columns, each holding a flattened 256-vector

Next, I’m flattening the 4x4 to be a sequence:

patched_squares = patched_squares.flatten(start_dim=2, end_dim=3)

which gives a batch of size 2, 3 channels, with sequence length = 16, where each sample in the sequence is a 256-vector

torch.Size([2, 3, 16, 256])

Lastly, I’m concatenating the channels of each “square” like so:

patched_squares = torch.stack([torch.cat([p[0], p[1], p[2]], dim=1) for p in patched_squares])

torch.Size([2, 16, 768])

The full manipulation code is:

batch = torch.rand((2,3,64,64))
patched_squares = batch.unfold(2,16,16).unfold(3,16,16).flatten(start_dim=4, end_dim=5).flatten(start_dim=2, end_dim=3)
patched_squares = torch.stack([torch.cat([p[0], p[1], p[2]], dim=1) for p in patched_squares])
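(Side note: if I'm reading the layout right, the per-sample loop with `torch.cat` can be replaced by a single `permute` + `reshape`, which produces the same [2, 16, 768] tensor with the channel blocks concatenated in the same order:)

```python
import torch

batch = torch.rand((2, 3, 64, 64))
patched = batch.unfold(2, 16, 16).unfold(3, 16, 16)  # [2, 3, 4, 4, 16, 16]
patched = patched.flatten(start_dim=4, end_dim=5)    # [2, 3, 4, 4, 256]
patched = patched.flatten(start_dim=2, end_dim=3)    # [2, 3, 16, 256]

# move channels next to the patch vectors, then merge them:
# [2, 3, 16, 256] -> [2, 16, 3, 256] -> [2, 16, 768]
vectorized = patched.permute(0, 2, 1, 3).reshape(2, 16, 768)

# same result as the cat-per-sample version
looped = torch.stack([torch.cat([p[0], p[1], p[2]], dim=1) for p in patched])
print(torch.equal(vectorized, looped))  # True
```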

And now to the challenge itself :upside_down_face:
After running some computation on patched_squares, I would like to “refold” it back to how it was,
so from [2, 16, 768] back to [2, 3, 64, 64], while maintaining the actual original positions.

I’m having a hard time refolding it; I would appreciate any help with this.
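(For what it's worth, this patching pattern looks like exactly what `torch.nn.functional.unfold` / `fold` implement. If I'm reading the layouts right, a single `unfold` call plus a transpose reproduces the [2, 16, 768] tensor, and `fold` is its exact inverse when the patches don't overlap — a sketch:)

```python
import torch
import torch.nn.functional as F

batch = torch.rand((2, 3, 64, 64))

# F.unfold extracts the 16x16 patches as columns: [2, 3*16*16, 16],
# each column ordered channel-major, row-major inside the patch
seq = F.unfold(batch, kernel_size=16, stride=16).transpose(1, 2)  # [2, 16, 768]

# F.fold scatters the patches back; with stride == kernel_size there
# is no overlap, so no normalization is needed
restored = F.fold(seq.transpose(1, 2), output_size=(64, 64),
                  kernel_size=16, stride=16)

print(torch.equal(restored, batch))  # True
```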


For future reference, I’ve managed to re-fold it; here is the solution I came up with:

start = torch.rand((2,3,64,64))
forward1 = start.unfold(2,16,16)
forward2 = forward1.unfold(3,16,16)
forward3 = forward2.flatten(start_dim=4, end_dim=5)
forward4 = forward3.flatten(start_dim=2, end_dim=3)
forward5 = torch.stack([torch.cat([p[0], p[1], p[2]], dim=1) for p in forward4])

# going back
back1 = forward5.unfold(2,256,256).permute(0,2,1,3) 	# back1 == forward4
back2 = back1.unfold(2,4,4).permute(0,1,2,4,3) 			# back2 == forward3
back3 = back2.unfold(4,16,16) 							# back3 == forward2
back4 = back3.permute(0,1,2,3,5,4).reshape(2,3,4,-1,16) # back4 == forward1
back5 = back4.permute(0,1,2,4,3).reshape(2,3,-1,64)		# back5 == start

(back5 == start).all()

I don’t know if it is the optimal way to do it, but it works :slight_smile:
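(A possible simplification, my own variant: since [2, 16, 768] is just (blocks-row × blocks-col, channel × patch-row × patch-col) flattened, the whole backward path can be written as one `reshape` + `permute` + `reshape` that splits those axes back out and moves each one home:)

```python
import torch

start = torch.rand((2, 3, 64, 64))

# forward, as in the post
fwd = start.unfold(2, 16, 16).unfold(3, 16, 16)
fwd = fwd.flatten(start_dim=4, end_dim=5).flatten(start_dim=2, end_dim=3)
fwd = torch.stack([torch.cat([p[0], p[1], p[2]], dim=1) for p in fwd])  # [2, 16, 768]

# backward in one go:
# [2, 16, 768] -> [b, i, j, c, p, q] -> [b, c, i, p, j, q] -> [2, 3, 64, 64]
back = fwd.reshape(2, 4, 4, 3, 16, 16).permute(0, 3, 1, 4, 2, 5).reshape(2, 3, 64, 64)

print((back == start).all())  # tensor(True)
```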