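Here’s my attempt. Z is an (n x m) tensor and S is an (l x k) tensor. I repeated Z and S into (n x l x m) and (n x l x k) tensors, respectively, and then concatenated them along the last dimension. The result is an (n x l x (m + k)) tensor whose entry [i, j] is row i of Z followed by row j of S.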
# Goal: given Z (n x m) and S (l x k), build `out` of shape n x l x (m + k)
# with out[i, j] == cat(Z[i], S[j]): every row of Z paired with every row of S.
import torch

n = 2
m = 3
k = 4
l = 5

z = torch.rand((n, m))  # n x m
# Example values from one run (torch.rand is unseeded, so they differ each run):
# tensor([[0.7424, 0.6298, 0.8356],
#         [0.5601, 0.1725, 0.9190]])

# S is drawn as k x l and transposed, giving an l x k tensor.
s = torch.rand((k, l)).T  # l x k
# tensor([[0.5839, 0.9947, 0.7128, 0.2356],
#         [0.4745, 0.8217, 0.4955, 0.7252],
#         [0.7342, 0.0209, 0.8368, 0.1695],
#         [0.4658, 0.3710, 0.6414, 0.1134],
#         [0.7196, 0.9393, 0.5347, 0.1267]])

# Broadcast rather than copy: unsqueeze + expand returns views over z / s
# without allocating new storage; torch.cat below makes the only real copy.
# Because these views share memory with z / s (and repeat elements), do not
# modify them in place -- call .clone() first if a writable copy is needed.
z_repeated = z.unsqueeze(1).expand(-1, l, -1)  # n x l x m; z_repeated[i, j] == z[i]
# tensor([[[0.7424, 0.6298, 0.8356],
#          [0.7424, 0.6298, 0.8356],
#          [0.7424, 0.6298, 0.8356],
#          [0.7424, 0.6298, 0.8356],
#          [0.7424, 0.6298, 0.8356]],
#         [[0.5601, 0.1725, 0.9190],
#          [0.5601, 0.1725, 0.9190],
#          [0.5601, 0.1725, 0.9190],
#          [0.5601, 0.1725, 0.9190],
#          [0.5601, 0.1725, 0.9190]]])

s_repeated = s.unsqueeze(0).expand(n, -1, -1)  # n x l x k; s_repeated[i, j] == s[j]
# tensor([[[0.5839, 0.9947, 0.7128, 0.2356],
#          [0.4745, 0.8217, 0.4955, 0.7252],
#          [0.7342, 0.0209, 0.8368, 0.1695],
#          [0.4658, 0.3710, 0.6414, 0.1134],
#          [0.7196, 0.9393, 0.5347, 0.1267]],
#         [[0.5839, 0.9947, 0.7128, 0.2356],
#          [0.4745, 0.8217, 0.4955, 0.7252],
#          [0.7342, 0.0209, 0.8368, 0.1695],
#          [0.4658, 0.3710, 0.6414, 0.1134],
#          [0.7196, 0.9393, 0.5347, 0.1267]]])

# Concatenate along the feature axis; the result is a fresh contiguous tensor.
out = torch.cat((z_repeated, s_repeated), dim=-1)  # n x l x (m + k)
# tensor([[[0.7424, 0.6298, 0.8356, 0.5839, 0.9947, 0.7128, 0.2356],
#          [0.7424, 0.6298, 0.8356, 0.4745, 0.8217, 0.4955, 0.7252],
#          [0.7424, 0.6298, 0.8356, 0.7342, 0.0209, 0.8368, 0.1695],
#          [0.7424, 0.6298, 0.8356, 0.4658, 0.3710, 0.6414, 0.1134],
#          [0.7424, 0.6298, 0.8356, 0.7196, 0.9393, 0.5347, 0.1267]],
#         [[0.5601, 0.1725, 0.9190, 0.5839, 0.9947, 0.7128, 0.2356],
#          [0.5601, 0.1725, 0.9190, 0.4745, 0.8217, 0.4955, 0.7252],
#          [0.5601, 0.1725, 0.9190, 0.7342, 0.0209, 0.8368, 0.1695],
#          [0.5601, 0.1725, 0.9190, 0.4658, 0.3710, 0.6414, 0.1134],
#          [0.5601, 0.1725, 0.9190, 0.7196, 0.9393, 0.5347, 0.1267]]])