Concatenation operation

Hi!

I am working with a model in which I would like to perform the following operation:

Z is an n x m matrix
S is a k x l matrix

I am looking for a PyTorch operation that concatenates each row of Z with each column of S to yield a 3D tensor, where entry i,j contains row i of Z concatenated with column j of S and therefore has length (m + k). The resulting tensor would have dimensions n x l x (m + k). Is there a convenient way to do this in PyTorch?

Thx!
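To pin down the requested semantics, here is a naive double-loop reference. It is only a sketch for checking faster versions against; `pairwise_concat_loop` is a hypothetical helper name, and it is far too slow for large n and l.

```python
import torch


def pairwise_concat_loop(Z: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
    """Reference (slow) version: out[i, j] = cat(Z[i, :], S[:, j])."""
    n, m = Z.shape
    k, l = S.shape
    # allocate with Z's dtype and device
    out = Z.new_empty((n, l, m + k))
    for i in range(n):
        for j in range(l):
            # row i of Z followed by column j of S
            out[i, j] = torch.cat((Z[i], S[:, j]))
    return out


Z = torch.rand(2, 3)  # n x m
S = torch.rand(4, 5)  # k x l
out = pairwise_concat_loop(Z, S)
print(out.shape)  # torch.Size([2, 5, 7])
```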

Here’s my attempt.
I transposed S to l x k, repeated Z and the transposed S into (n x l x m) and (n x l x k) tensors respectively, then concatenated them along the last dimension.

```python
import torch

n = 2
m = 3
k = 4
l = 5
z = torch.rand((n, m))  # Z: n x m
'''
tensor([[0.7424, 0.6298, 0.8356],
        [0.5601, 0.1725, 0.9190]])
'''
# S is k x l; transpose to l x k so that row j of s is column j of S
s = torch.rand((k, l)).T  # l x k
'''
tensor([[0.5839, 0.9947, 0.7128, 0.2356],
        [0.4745, 0.8217, 0.4955, 0.7252],
        [0.7342, 0.0209, 0.8368, 0.1695],
        [0.4658, 0.3710, 0.6414, 0.1134],
        [0.7196, 0.9393, 0.5347, 0.1267]])
'''

# repeat each row of z l times, then group the copies per row
z_repeated = z.repeat_interleave(l, dim=0).reshape(n, l, m)  # n x l x m
'''
tensor([[[0.7424, 0.6298, 0.8356],
         [0.7424, 0.6298, 0.8356],
         [0.7424, 0.6298, 0.8356],
         [0.7424, 0.6298, 0.8356],
         [0.7424, 0.6298, 0.8356]],

        [[0.5601, 0.1725, 0.9190],
         [0.5601, 0.1725, 0.9190],
         [0.5601, 0.1725, 0.9190],
         [0.5601, 0.1725, 0.9190],
         [0.5601, 0.1725, 0.9190]]])
'''
# stack the whole of s n times
s_repeated = s.repeat((n, 1, 1))  # n x l x k
'''
tensor([[[0.5839, 0.9947, 0.7128, 0.2356],
         [0.4745, 0.8217, 0.4955, 0.7252],
         [0.7342, 0.0209, 0.8368, 0.1695],
         [0.4658, 0.3710, 0.6414, 0.1134],
         [0.7196, 0.9393, 0.5347, 0.1267]],

        [[0.5839, 0.9947, 0.7128, 0.2356],
         [0.4745, 0.8217, 0.4955, 0.7252],
         [0.7342, 0.0209, 0.8368, 0.1695],
         [0.4658, 0.3710, 0.6414, 0.1134],
         [0.7196, 0.9393, 0.5347, 0.1267]]])
'''
# entry [i, j] is row i of Z followed by column j of S
out = torch.cat((z_repeated, s_repeated), dim=-1)  # n x l x (m + k)
'''
tensor([[[0.7424, 0.6298, 0.8356, 0.5839, 0.9947, 0.7128, 0.2356],
         [0.7424, 0.6298, 0.8356, 0.4745, 0.8217, 0.4955, 0.7252],
         [0.7424, 0.6298, 0.8356, 0.7342, 0.0209, 0.8368, 0.1695],
         [0.7424, 0.6298, 0.8356, 0.4658, 0.3710, 0.6414, 0.1134],
         [0.7424, 0.6298, 0.8356, 0.7196, 0.9393, 0.5347, 0.1267]],

        [[0.5601, 0.1725, 0.9190, 0.5839, 0.9947, 0.7128, 0.2356],
         [0.5601, 0.1725, 0.9190, 0.4745, 0.8217, 0.4955, 0.7252],
         [0.5601, 0.1725, 0.9190, 0.7342, 0.0209, 0.8368, 0.1695],
         [0.5601, 0.1725, 0.9190, 0.4658, 0.3710, 0.6414, 0.1134],
         [0.5601, 0.1725, 0.9190, 0.7196, 0.9393, 0.5347, 0.1267]]])
'''
```
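A possible variant, not part of the answer above: broadcasting with `unsqueeze` + `expand` gives the same result, but `expand` returns views rather than copies, so the only new allocation is the one made by `torch.cat`. This is a sketch under the assumption that Z is n x m and S is k x l as in the question; `pairwise_concat` is a hypothetical helper name.

```python
import torch


def pairwise_concat(Z: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
    """out[i, j] = cat(Z[i, :], S[:, j]); result shape n x l x (m + k)."""
    n, m = Z.shape
    k, l = S.shape
    # (n, m) -> (n, 1, m) -> (n, l, m): a view, no data copied
    z_exp = Z.unsqueeze(1).expand(n, l, m)
    # S.T is (l, k) -> (1, l, k) -> (n, l, k): also a view
    s_exp = S.T.unsqueeze(0).expand(n, l, k)
    # torch.cat materializes the result once
    return torch.cat((z_exp, s_exp), dim=-1)


n, m, k, l = 2, 3, 4, 5
Z = torch.rand(n, m)
S = torch.rand(k, l)
out = pairwise_concat(Z, S)

# compare with the repeat-based approach from the answer
ref = torch.cat(
    (Z.repeat_interleave(l, dim=0).reshape(n, l, m), S.T.repeat((n, 1, 1))),
    dim=-1,
)
print(out.shape)              # torch.Size([2, 5, 7])
print(torch.equal(out, ref))  # True
```

Both versions produce the same tensor; the `expand` form mainly avoids the two intermediate n x l copies, which may matter when n and l are large. Gradients flow through `expand` and `torch.cat` as usual.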