Create all possible combinations of a 3D tensor along dimension 1

I have a 3D tensor of shape (batch size, m, n), and I would like to create all possible pairwise combinations of its rows along axis 1, giving a tensor of shape (batch size, m^2, 2n). For example, given the following tensor

input_tensor = torch.Tensor([[[1, 2, 3, 4],
         [5, 6, 7, 8]],

        [[9, 10, 11, 12],
         [13, 14, 15, 16]],

        [[17, 18, 19, 20],
         [21, 22, 23, 24]]])

with size torch.Size([3, 2, 4]).

I would like to have the following output tensor

output_tensor = torch.Tensor([[[1, 2, 3, 4, 1, 2, 3, 4], [1, 2, 3, 4, 5, 6, 7, 8], [5, 6, 7, 8, 5, 6, 7, 8], [5, 6, 7, 8, 1, 2, 3, 4]],

        [[9, 10, 11, 12, 9, 10, 11, 12], [9, 10, 11, 12, 13, 14, 15, 16], [13, 14, 15, 16, 13, 14, 15, 16], [13, 14, 15, 16, 9, 10, 11, 12]],

        [[17, 18, 19, 20, 17, 18, 19, 20], [17, 18, 19, 20, 21, 22, 23, 24], [21, 22, 23, 24, 21, 22, 23, 24], [21, 22, 23, 24, 17, 18, 19, 20]]])

with size torch.Size([3, 4, 8]).

I have tried several things, and the one that came closest is

from itertools import product
output_tensor = torch.stack([torch.cat(batch,1) for batch in product(input_tensor, repeat=2)])

… but it still does not work. I had a look at this link but it did not help.

UPDATE:
I have just managed to get the output_tensor with two nested loops:

output_tensor = torch.stack([torch.stack([torch.cat(t,0) for t in product(row, repeat=2)]) for row in input_tensor])

Is there a way to obtain output_tensor with vectorized operations instead of loops?

You can proceed like this:

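# first tiles the rows as a block: r0, r1, r0, r1 (the repeat count 2 is m in this example)
first = input_tensor.repeat(1, 2, 1)
# second repeats each row consecutively: r0, r0, r1, r1
second = input_tensor.unsqueeze(2)
second = second.repeat(1,1,2,1).view(input_tensor.size(0),-1,input_tensor.size(2))
# pairs come out as (r0,r0), (r1,r0), (r0,r1), (r1,r1); use (second, first) to match itertools.product order
output_tensor = torch.cat((first,second), dim=2)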
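
For a general m, the same idea can probably be written with broadcasting instead of hard-coded repeat counts. The following is only a sketch, not a tested drop-in: pairwise_rows is a hypothetical helper name, and it assumes you want the pairs in the same order as itertools.product(row, repeat=2). It compares the result against the nested-loop version from the question.

```python
from itertools import product

import torch


def pairwise_rows(x: torch.Tensor) -> torch.Tensor:
    """Concatenate every ordered pair of rows along dim 1.

    x has shape (batch, m, n); the result has shape (batch, m*m, 2*n).
    Hypothetical helper, written for illustration only.
    """
    b, m, n = x.shape
    # left[b, i, j] = x[b, i]: the first row of each pair varies slowest.
    left = x.unsqueeze(2).expand(b, m, m, n)
    # right[b, i, j] = x[b, j]: the second row of each pair varies fastest.
    right = x.unsqueeze(1).expand(b, m, m, n)
    # Join each pair along the last axis, then flatten the (m, m) grid.
    return torch.cat((left, right), dim=-1).reshape(b, m * m, 2 * n)


# Same values as the example tensor in the question, shape (3, 2, 4).
input_tensor = torch.arange(1, 25, dtype=torch.float32).reshape(3, 2, 4)

# Nested-loop reference from the question.
looped = torch.stack([
    torch.stack([torch.cat(t, 0) for t in product(row, repeat=2)])
    for row in input_tensor
])

vectorized = pairwise_rows(input_tensor)
print(vectorized.shape)
print(torch.equal(looped, vectorized))
```

expand only creates views, so the m*m copies are materialized once, inside torch.cat, rather than once per repeat call.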