I have a 3D tensor of shape (batch size, m, n), and I would like to build all possible pairs of rows along axis 1, concatenating each pair, so that the result has shape (batch size, m^2, 2n). For example, given the following tensor
input_tensor = torch.Tensor([[[1, 2, 3, 4],
                              [5, 6, 7, 8]],
                             [[9, 10, 11, 12],
                              [13, 14, 15, 16]],
                             [[17, 18, 19, 20],
                              [21, 22, 23, 24]]])
with size torch.Size([3, 2, 4]).
I would like to obtain the following output tensor:
output_tensor = torch.Tensor([[[1, 2, 3, 4, 1, 2, 3, 4], [1, 2, 3, 4, 5, 6, 7, 8], [5, 6, 7, 8, 5, 6, 7, 8], [5, 6, 7, 8, 1, 2, 3, 4]],
                              [[9, 10, 11, 12, 9, 10, 11, 12], [9, 10, 11, 12, 13, 14, 15, 16], [13, 14, 15, 16, 13, 14, 15, 16], [13, 14, 15, 16, 9, 10, 11, 12]],
                              [[17, 18, 19, 20, 17, 18, 19, 20], [17, 18, 19, 20, 21, 22, 23, 24], [21, 22, 23, 24, 21, 22, 23, 24], [21, 22, 23, 24, 17, 18, 19, 20]]])
with size torch.Size([3, 4, 8]).
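One way to read the example: every output row is [x[i], x[j]] for an index pair (i, j) over axis 1. A minimal sketch of that reading, assuming torch.cartesian_prod to enumerate the index pairs and advanced indexing to gather the rows; pair_rows is a hypothetical helper name, and the input is rebuilt with torch.arange, which gives the same values as the literal tensor above.

```python
import torch

def pair_rows(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, m, n) -> (B, m*m, 2n), row (i, j) = [x[:, i], x[:, j]]."""
    b, m, n = x.shape
    r = torch.arange(m)
    # All (i, j) index pairs in itertools.product order: (0, 0), (0, 1), (1, 0), (1, 1), ...
    idx = torch.cartesian_prod(r, r)  # shape (m*m, 2)
    # Advanced indexing on axis 1 gathers both rows of every pair: shape (B, m*m, 2, n)
    gathered = x[:, idx]
    # Merge the two gathered rows of each pair into a single row of length 2n
    return gathered.reshape(b, m * m, 2 * n)

# Same values as the example input_tensor (float32, shape (3, 2, 4))
input_tensor = torch.arange(1, 25, dtype=torch.float32).reshape(3, 2, 4)
out = pair_rows(input_tensor)
print(out.shape)  # torch.Size([3, 4, 8])
print(out[0])
```

Rows come out in itertools.product order, so within each batch [5, 6, 7, 8, 1, 2, 3, 4] comes before [5, 6, 7, 8, 5, 6, 7, 8], i.e. the last two rows are swapped relative to the listing above. If that exact order matters, reorder idx before indexing.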
I have tried several things, and the closest I got is
output_tensor = torch.stack([torch.cat(batch,1) for batch in product(input_tensor, repeat=2)])
… but it still does not work. I had a look at this link, but it did not help.
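The shapes alone show where that attempt goes wrong. A small diagnostic sketch (the input is rebuilt with torch.arange, equivalent to the literal example): iterating over a tensor walks dim 0, so product(input_tensor, repeat=2) pairs whole batches rather than rows within a batch.

```python
import torch
from itertools import product

# Same values as the example input_tensor, shape (3, 2, 4)
input_tensor = torch.arange(1, 25, dtype=torch.float32).reshape(3, 2, 4)

# Iterating a tensor yields slices along dim 0, so product() pairs whole batches:
# 3 * 3 = 9 pairs, each a tuple of two (2, 4) tensors.
pairs = list(product(input_tensor, repeat=2))
print(len(pairs))  # 9

# Concatenating a pair along dim 1 puts the two batches side by side -> (2, 8)
attempt = torch.stack([torch.cat(p, 1) for p in pairs])
print(attempt.shape)  # torch.Size([9, 2, 8])
```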
UPDATE:
I have just managed to get output_tensor with two nested loops:
import torch
from itertools import product

output_tensor = torch.stack([torch.stack([torch.cat(t, 0) for t in product(row, repeat=2)]) for row in input_tensor])
Is there a way to obtain output_tensor through vectorized operations instead of Python loops?
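A loop-free sketch, assuming broadcasting with unsqueeze/expand is acceptable: each row is broadcast once along a new "i" axis and once along a new "j" axis, the two views are concatenated on the feature axis, and the (i, j) grid is flattened into m*m rows. pairwise_concat is a hypothetical helper name; the check at the end compares it with the nested-loop version from the update.

```python
import torch
from itertools import product

def pairwise_concat(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, m, n) -> (B, m*m, 2n) without Python loops."""
    b, m, n = x.shape
    # left[:, i, j] = x[:, i]  (row i broadcast across the j axis; expand does not copy)
    left = x.unsqueeze(2).expand(b, m, m, n)
    # right[:, i, j] = x[:, j]  (row j broadcast across the i axis)
    right = x.unsqueeze(1).expand(b, m, m, n)
    # Concatenate along features -> (B, m, m, 2n), then flatten (i, j) row-major
    return torch.cat([left, right], dim=-1).reshape(b, m * m, 2 * n)

# Same values as the example input_tensor, shape (3, 2, 4)
input_tensor = torch.arange(1, 25, dtype=torch.float32).reshape(3, 2, 4)
out = pairwise_concat(input_tensor)

# Reference: the nested-loop version from the update
ref = torch.stack([torch.stack([torch.cat(t, 0) for t in product(row, repeat=2)])
                   for row in input_tensor])
print(out.shape)              # torch.Size([3, 4, 8])
print(torch.equal(out, ref))  # True
```

Flattening the (i, j) grid row-major yields the same pair order as itertools.product, which is why the result matches the loop version exactly. The output holds B * m^2 * 2n elements, so memory grows quadratically in m; expand itself creates views, but torch.cat materializes the full result.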