Advanced multiplication of 2D tensors to get a 3D tensor

Let x and emb be 2D tensors of shape (bsz, n) and (n, m), respectively.

import torch

x = torch.FloatTensor([[0,1,2], [3,4,5]])
emb = torch.FloatTensor([[0,1,2,3], [4,5,6,7], [8,9,10,11]])

# x
# tensor([[ 0.,  1.,  2.],
#         [ 3.,  4.,  5.]])

# emb
# tensor([[  0.,   1.,   2.,   3.],
#         [  4.,   5.,   6.,   7.],
#         [  8.,   9.,  10.,  11.]])

I want the result to be a 3D tensor of size (bsz, n, m) where out[j, i, :] = x[j, i] * emb[i, :]. For now I am using the loop below, but is there a better way?

bsz, n = x.shape
m = emb.shape[1]
out = torch.zeros(bsz, n, m)
for j in range(bsz):
    # x[j] as a column (n, 1) scales row i of emb by x[j, i]
    out[j] = x[j].view(-1, 1) * emb
    
# out
# tensor([[[  0.,   0.,   0.,   0.],
#          [  4.,   5.,   6.,   7.],
#          [ 16.,  18.,  20.,  22.]],
#
#         [[  0.,   3.,   6.,   9.],
#          [ 16.,  20.,  24.,  28.],
#          [ 40.,  45.,  50.,  55.]]])

I like to use indexing and broadcasting for this:

out = x[:, :, None] * emb[None, :, :]

Indexing with None inserts a new dimension of size 1 (“unsqueeze” or “view” could achieve the same, but I like the form above because it helps me see the alignment). Broadcasting then expands the size-1 dimensions, so x acts as (bsz, n, 1) and emb as (1, n, m).
For more advanced uses, torch.einsum can be useful, too. Here, it would be
torch.einsum('bn,nm->bnm', [x, emb]). The subscripts replicate your indexing, except that I name them after the dimensions (b, n, m) rather than j and i.
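To check that these variants agree, here is a minimal sketch that compares the broadcasting form, an explicit unsqueeze form, and the einsum form against the loop from the question. The variable names (loop_out, broadcast_out, and so on) are only for illustration, and the einsum call passes the operands as separate arguments, which I assume your PyTorch version accepts alongside the list form.

```python
import torch

x = torch.tensor([[0., 1., 2.], [3., 4., 5.]])               # (bsz, n)
emb = torch.arange(12, dtype=torch.float32).reshape(3, 4)    # (n, m)

bsz, n = x.shape
m = emb.shape[1]

# Reference result: the loop from the question.
loop_out = torch.zeros(bsz, n, m)
for j in range(bsz):
    loop_out[j] = x[j].view(-1, 1) * emb

# Indexing with None adds size-1 dimensions; broadcasting does the rest.
broadcast_out = x[:, :, None] * emb[None, :, :]     # (bsz, n, 1) * (1, n, m)

# The same alignment written with unsqueeze.
unsqueeze_out = x.unsqueeze(2) * emb.unsqueeze(0)

# einsum: no index is dropped from the output, so nothing is summed.
einsum_out = torch.einsum('bn,nm->bnm', x, emb)

print(broadcast_out.shape)
print(torch.allclose(broadcast_out, loop_out))
print(torch.allclose(unsqueeze_out, loop_out))
print(torch.allclose(einsum_out, loop_out))
```

Since emb already has n as its first dimension, x[:, :, None] * emb would also broadcast correctly; spelling out emb[None, :, :] just makes the alignment visible.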

Best regards

Thomas
