Tensor concatenation

How can I get the following output?

Inputs:

bx1 = torch.tensor([[1,2],[3,4],[5,6],[7,8]])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
bx2 = torch.tensor([[11,22],[33,44]])
tensor([[11, 22],
        [33, 44]])

Desired output:

tensor([[[ 1,  2, 11, 22],
         [ 1,  2, 33, 44]],

        [[ 3,  4, 11, 22],
         [ 3,  4, 33, 44]],

        [[ 5,  6, 11, 22],
         [ 5,  6, 33, 44]],

        [[ 7,  8, 11, 22],
         [ 7,  8, 33, 44]]])
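Put differently, row j of block i is bx1[i] followed by bx2[j], so the result has shape (4, 2, 4). A plain-loop sketch of that definition is slow but useful as a reference to check a vectorized version against. The helper name `pairwise_cat_loop` is hypothetical, chosen only for illustration:

```python
import torch

def pairwise_cat_loop(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference helper: out[i, j] = cat(a[i], b[j])
    blocks = []
    for i in range(a.size(0)):
        # Pair row i of a with every row of b, giving one (m, p + q) block
        blocks.append(torch.stack([torch.cat((a[i], b[j])) for j in range(b.size(0))]))
    # Stack the n blocks into an (n, m, p + q) tensor
    return torch.stack(blocks)

bx1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
bx2 = torch.tensor([[11, 22], [33, 44]])
out = pairwise_cat_loop(bx1, bx2)
print(out.shape)  # torch.Size([4, 2, 4])
```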

I think this works:

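import torch

bx1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
bx2 = torch.tensor([[11, 22], [33, 44]])
n, m = bx1.size(0), bx2.size(0)  # 4 rows in bx1, 2 rows in bx2
# Repeat each row of bx1 m times, then group the copies into shape (n, m, 2)
temp = torch.repeat_interleave(bx1, m, dim=0).reshape(n, m, -1)
# Tile the whole of bx2 n times to shape (n, m, 2) and join along the last dim
output = torch.cat((temp, bx2.repeat(n, 1, 1)), dim=2)
print(output)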
tensor([[[ 1,  2, 11, 22],
         [ 1,  2, 33, 44]],

        [[ 3,  4, 11, 22],
         [ 3,  4, 33, 44]],

        [[ 5,  6, 11, 22],
         [ 5,  6, 33, 44]],

        [[ 7,  8, 11, 22],
         [ 7,  8, 33, 44]]])
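An alternative that skips repeat_interleave and reshape is to broadcast with unsqueeze and expand. Both return views without copying data, so the only copy happens inside torch.cat. This is a sketch that assumes both inputs are 2-D with the same dtype; `pairwise_cat` is a hypothetical helper name:

```python
import torch

def pairwise_cat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: out[i, j] = cat(a[i], b[j]); assumes a is (n, p) and b is (m, q)
    n, m = a.size(0), b.size(0)
    a_exp = a.unsqueeze(1).expand(n, m, a.size(1))  # (n, m, p): each a[i] seen m times (view)
    b_exp = b.unsqueeze(0).expand(n, m, b.size(1))  # (n, m, q): all of b seen n times (view)
    return torch.cat((a_exp, b_exp), dim=2)         # (n, m, p + q), materialized here

bx1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
bx2 = torch.tensor([[11, 22], [33, 44]])
output = pairwise_cat(bx1, bx2)
print(output)
```

Because n and m come from the inputs, the same helper also works when bx1 and bx2 have different numbers of rows or columns.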