Batch Outer Product

Title says it all. Is there a way to compute a batch outer product? I noticed that PyTorch conveniently has torch.ger, which takes two one-dimensional vectors and outputs their outer product: (n,) * (m,) -> (n, m)

Is there a batch version of this operation? (b, n) * (b, m) -> (b, n, m)
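Written elementwise, with b as the batch index, the requested operation is one ordinary outer product per batch entry, with no summation over any index:

$$P_{b,i,j} = x_{b,i}\, y_{b,j}, \qquad x \in \mathbb{R}^{B \times n},\quad y \in \mathbb{R}^{B \times m},\quad P \in \mathbb{R}^{B \times n \times m}$$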


No, there is no batch version of ger. But you can get the same result by adding a singleton dimension to each tensor (with unsqueeze) and using a batched matrix multiply:

import torch

x = torch.rand(10, 3)  # (b, n)
y = torch.rand(10, 3)  # (b, m); m does not have to equal n

# (b, n, 1) @ (b, 1, m) -> (b, n, m)
res = torch.bmm(x.unsqueeze(2), y.unsqueeze(1))
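To see why this works: x.unsqueeze(2) has shape (b, n, 1) and y.unsqueeze(1) has shape (b, 1, m), so each matrix product in the batch multiplies an n x 1 column by a 1 x m row, which is exactly an outer product. The sketch below assumes illustrative sizes with n != m so the axes are easy to tell apart; the plain broadcasting variant is an equivalent formulation, not part of the reply above.

```python
import torch

b, n, m = 10, 3, 5  # illustrative sizes; n != m keeps the axes distinguishable
x = torch.rand(b, n)
y = torch.rand(b, m)

col = x.unsqueeze(2)  # shape (b, n, 1): a batch of column vectors
row = y.unsqueeze(1)  # shape (b, 1, m): a batch of row vectors
res = torch.bmm(col, row)  # (b, n, 1) @ (b, 1, m) -> (b, n, m)
print(res.shape)  # torch.Size([10, 3, 5])

# Same result without bmm: elementwise multiplication broadcasts
# (b, n, 1) * (b, 1, m) to (b, n, m).
res_bcast = col * row
print(torch.allclose(res, res_bcast))  # True

# Spot-check one batch entry against the 1-D outer product
# (torch.outer is the newer name for torch.ger).
print(torch.allclose(res[0], torch.outer(x[0], y[0])))  # True
```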

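A per-sample loop also works, though it is slower than a single batched call: take the outer product of each pair of rows and stack the results:

res = torch.stack([torch.ger(a, b) for a, b in zip(x, y)])  # (b, n, m); torch.outer is the non-deprecated name for torch.ger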

Looks like you can also do it with einsum:

einsum = torch.einsum('bi,bj->bij', x, y)
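In the subscript string 'bi,bj->bij', b appears in both inputs and in the output, so it is carried through as a batch axis; i and j each appear in the output as well, so nothing is summed and every element is simply x[b, i] * y[b, j]. A small sketch, with illustrative sizes, that checks this against a plain Python loop written straight from that definition:

```python
import torch

b, n, m = 2, 3, 4  # illustrative sizes
x = torch.rand(b, n)
y = torch.rand(b, m)

# Operands can be passed directly after the equation string.
out = torch.einsum('bi,bj->bij', x, y)  # shape (b, n, m)

# Reference implementation: one product per (batch, i, j) triple, no summation.
ref = torch.empty(b, n, m)
for k in range(b):
    for i in range(n):
        for j in range(m):
            ref[k, i, j] = x[k, i] * y[k, j]

print(out.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(out, ref))  # True
```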

Full example:

import torch
x = torch.ones(4, 3)
y = torch.rand(4, 3)
bmm = torch.bmm(x.unsqueeze(2), y.unsqueeze(1)) # Method of @fmassa
einsum = torch.einsum('bi,bj->bij', x, y) # Method with einsum
print(torch.allclose(bmm, einsum)) # True: both give the same (4, 3, 3) tensor
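torch.bmm only accepts 3-D tensors, so it handles exactly one batch dimension. If the inputs carry several leading batch dimensions, e.g. (a, b, n) and (a, b, m), one option (a sketch with illustrative shapes, not taken from the replies above) is einsum with an ellipsis, plain broadcasting on the last two axes, or torch.matmul, which broadcasts over leading dimensions:

```python
import torch

x = torch.rand(2, 4, 3)  # leading batch dims (2, 4), vector length n = 3
y = torch.rand(2, 4, 5)  # same batch dims, vector length m = 5

# '...' stands for any number of shared leading dimensions.
outer_einsum = torch.einsum('...i,...j->...ij', x, y)

# Broadcasting: (..., n, 1) * (..., 1, m) -> (..., n, m)
outer_bcast = x.unsqueeze(-1) * y.unsqueeze(-2)

# matmul treats the last two dims as matrices and broadcasts the rest.
outer_matmul = torch.matmul(x.unsqueeze(-1), y.unsqueeze(-2))

print(outer_einsum.shape)  # torch.Size([2, 4, 3, 5])
print(torch.allclose(outer_einsum, outer_bcast))  # True
print(torch.allclose(outer_einsum, outer_matmul))  # True
```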