How to compute batched outer products between vectors with broadcasting?

Is there a single operation that achieves the following?

import torch

batch_size = 2
seq_len = 2
dim = 3
# batch of sequences of embedding vectors, shape (batch_size, seq_len, dim):
x = torch.rand([batch_size, seq_len, dim])
# batch of target embedding vectors, shape (batch_size, dim):
y = torch.rand([batch_size, dim])

# the computation I want to achieve: for every batch index b and sequence
# position s, the (dim, dim) outer product of x[b][s] with y[b] -- i.e. y[b]
# is reused (broadcast) across the sequence dimension. Written out by hand
# here for the batch_size = 2, seq_len = 2 case:
print(torch.outer(x[0][0], y[0]))
print(torch.outer(x[0][1], y[0]))
print(torch.outer(x[1][0], y[1]))
print(torch.outer(x[1][1], y[1]))
print()

# what I've tried but failed:
# computes the elementwise x[b, i, j] * y[b, j]; the result has shape
# (batch_size, seq_len, dim), with no second dim axis for the outer product:
#print(torch.einsum('bij, bj->bij', x, y))
# or
# raises an error: the output index k does not appear in either input:
#print(torch.einsum('bij, bj->bijk', x, y))

Hi Yu-Qing!

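torch.einsum ('ijk, il -> ijkl', x, y) should do what you want. The batch index i is shared by both operands, j and k come from x, and l comes from y. No index is summed away, so the result is z[i, j, k, l] = x[i, j, k] * y[i, l]: for each batch and sequence position, the outer product of x[i, j] with y[i]. The result has shape (batch_size, seq_len, dim, dim). The same values can also be obtained with plain broadcasting: x[:, :, :, None] * y[:, None, None, :].

Your first attempt ('bij, bj->bij') only multiplies elementwise, so it has no second dim axis. Your second attempt ('bij, bj->bijk') names an output index k that appears in neither input. Here is a check against your loop: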

>>> import torch
>>> torch.__version__
'1.9.0'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> batch_size = 2
>>> seq_len = 2
>>> dim = 3
>>> x = torch.rand([batch_size, seq_len, dim])
>>> y = torch.rand([batch_size, dim])
>>>
>>> z = torch.empty (batch_size, seq_len, dim, dim)
>>>
>>> z[0, 0] = torch.outer(x[0][0], y[0])
>>> z[0, 1] = torch.outer(x[0][1], y[0])
>>> z[1, 0] = torch.outer(x[1][0], y[1])
>>> z[1, 1] = torch.outer(x[1][1], y[1])
>>>
>>> torch.einsum ('ijk, il -> ijkl', x, y).equal (z)
True
>>>
>>> print(torch.outer(x[0][0], y[0]))
tensor([[0.3167, 0.2910, 0.2960],
        [0.7378, 0.6779, 0.6894],
        [0.6072, 0.5579, 0.5674]])
>>> print(torch.outer(x[0][1], y[0]))
tensor([[0.3050, 0.2802, 0.2850],
        [0.0210, 0.0193, 0.0196],
        [0.2876, 0.2642, 0.2687]])
>>> print(torch.outer(x[1][0], y[1]))
tensor([[0.4431, 0.4939, 0.6687],
        [0.4362, 0.4863, 0.6584],
        [0.2593, 0.2891, 0.3914]])
>>> print(torch.outer(x[1][1], y[1]))
tensor([[0.3535, 0.3940, 0.5335],
        [0.3736, 0.4165, 0.5639],
        [0.1286, 0.1433, 0.1941]])
>>>
>>> torch.einsum ('ijk, il -> ijkl', x, y)
tensor([[[[0.3167, 0.2910, 0.2960],
          [0.7378, 0.6779, 0.6894],
          [0.6072, 0.5579, 0.5674]],

         [[0.3050, 0.2802, 0.2850],
          [0.0210, 0.0193, 0.0196],
          [0.2876, 0.2642, 0.2687]]],


        [[[0.4431, 0.4939, 0.6687],
          [0.4362, 0.4863, 0.6584],
          [0.2593, 0.2891, 0.3914]],

         [[0.3535, 0.3940, 0.5335],
          [0.3736, 0.4165, 0.5639],
          [0.1286, 0.1433, 0.1941]]]])

Best.

K. Frank

1 Like