How can I get the following output, where each row of bx1 is concatenated with every row of bx2?
Inputs:
import torch
bx1 = torch.tensor([[1,2],[3,4],[5,6],[7,8]])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
bx2 = torch.tensor([[11,22],[33,44]])
tensor([[11, 22],
        [33, 44]])
Output:
tensor([[[ 1,  2, 11, 22],
         [ 1,  2, 33, 44]],

        [[ 3,  4, 11, 22],
         [ 3,  4, 33, 44]],

        [[ 5,  6, 11, 22],
         [ 5,  6, 33, 44]],

        [[ 7,  8, 11, 22],
         [ 7,  8, 33, 44]]])
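One possible way to build this is sketched below. It assumes both inputs are 2-D tensors with the same dtype. The idea: give bx1 a new middle axis and bx2 a new leading axis, expand both to a common (rows of bx1, rows of bx2, features) shape, then concatenate along the last dimension. The names n1, n2, left, and right are illustrative only.

```python
import torch

bx1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
bx2 = torch.tensor([[11, 22], [33, 44]])

n1, n2 = bx1.size(0), bx2.size(0)

# (n1, d1) -> (n1, 1, d1) -> (n1, n2, d1): row i of bx1 repeated n2 times
left = bx1.unsqueeze(1).expand(n1, n2, bx1.size(1))

# (n2, d2) -> (1, n2, d2) -> (n1, n2, d2): all rows of bx2 repeated for each row of bx1
right = bx2.unsqueeze(0).expand(n1, n2, bx2.size(1))

# Concatenate along the feature axis: out[i, j] == [*bx1[i], *bx2[j]]
out = torch.cat([left, right], dim=2)
print(out)  # shape (4, 2, 4)
```

expand returns views without copying data, so the only new allocation is the result of torch.cat.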