Broadcasting with Bilinear function

Hi there, I would like some help computing a vector "outer product" (a score for every pair of vectors), but with a bilinear function instead of a plain dot product. For example, say we have the following vectors:

import torch

batch, num_vectors, dim = 1, 10, 128
a = torch.rand(batch, num_vectors, dim)  # |a| = (batch, num_vectors, dim)
b = torch.rand(batch, num_vectors, dim)  # |b| = (batch, num_vectors, dim)

# With plain dot products between every pair of vectors, we could just do
c = torch.bmm(a, b.permute(0, 2, 1))  # |c| = (batch, num_vectors, num_vectors)

# I would like to do the same operation, but using a Bilinear layer.
# However, Bilinear only pairs rows at the same position, so we need to repeat
# the vectors so that every (i, j) pair lines up:
a = a.repeat_interleave(num_vectors, dim=1)
b = b.repeat(1, num_vectors, 1)

bilinear_layer = torch.nn.Bilinear(dim, dim, 1)
c = bilinear_layer(a, b)  # |c| = (batch, num_vectors * num_vectors, 1)
c = c.view(batch, num_vectors, num_vectors)

The problem is that repeat_interleave and repeat both allocate new memory: each repeated tensor is num_vectors times larger than the original. Is there a way to broadcast this operation efficiently and avoid that?
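To make that cost concrete, here is a small sketch using the question's shapes (batch = 1; the pair i, j = 3, 7 is an arbitrary choice for the check). It confirms that row i * num_vectors + j of the repeated tensors holds the pair (a[:, i], b[:, j]), and that each repeated tensor is a fresh buffer num_vectors times the size of its input, so memory grows with num_vectors² rather than num_vectors.

```python
import torch

batch, num_vectors, dim = 1, 10, 128
a = torch.rand(batch, num_vectors, dim)
b = torch.rand(batch, num_vectors, dim)

# Row i * num_vectors + j of the repeated tensors holds the pair (a[:, i], b[:, j])
a_rep = a.repeat_interleave(num_vectors, dim=1)  # (batch, num_vectors**2, dim)
b_rep = b.repeat(1, num_vectors, 1)              # (batch, num_vectors**2, dim)

i, j = 3, 7  # arbitrary pair used for the check
print(torch.equal(a_rep[:, i * num_vectors + j], a[:, i]))  # True
print(torch.equal(b_rep[:, i * num_vectors + j], b[:, j]))  # True

# Each repeated tensor has num_vectors times as many elements as its input
print(a.numel(), a_rep.numel(), b_rep.numel())  # 1280 12800 12800

# ...and lives in newly allocated memory, not in a view of a
print(a_rep.data_ptr() == a.data_ptr())  # False
```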

Hi Legoh!

I do not believe that there is a way to tell (or trick) Bilinear into
broadcasting internally, and I don't believe that there is a way to reshape
a and b appropriately without allocating memory.
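The reshape part of that claim can be checked directly. In this sketch (shapes from the question), expand() produces a zero-copy broadcast view whose repeated axis has stride 0, but flattening the (i, j) axes into the single num_vectors * num_vectors axis used in the question cannot be expressed as a view, so reshape() quietly copies. Whether Bilinear could consume the 4-D expanded view without copying internally is not tested here.

```python
import torch

batch, num_vectors, dim = 1, 10, 128
a = torch.rand(batch, num_vectors, dim)

# expand() creates a broadcast view: no new memory, the repeated axis has stride 0
a_view = a.unsqueeze(2).expand(batch, num_vectors, num_vectors, dim)
print(a_view.data_ptr() == a.data_ptr())  # True
print(a_view.stride())                    # (1280, 128, 0, 1)

# Merging a stride-128 axis with a stride-0 axis cannot be done as a view,
# so reshape() falls back to allocating a copy
a_flat = a_view.reshape(batch, num_vectors * num_vectors, dim)
print(a_flat.data_ptr() == a.data_ptr())  # False
```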

I think the cleanest approach is to use einsum() with Bilinear.weight
(and then add Bilinear.bias) to get, in effect, the broadcasting you want:

>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> batch = 2
>>> num_vectors = 10
>>> dim = 128
>>>
>>> a = torch.rand (batch, num_vectors, dim)  # |a| = (batch, num_vectors, dim)
>>> b = torch.rand (batch, num_vectors, dim)  # |b| = (batch, num_vectors, dim)
>>> aa = a.clone()
>>> bb = b.clone()
>>>
>>> # I would like to do the same operation, but using a Bilinear layer.
>>> # However, Bilinear only pairs rows at the same position, so the vectors must be repeated:
>>> a = a.repeat_interleave (num_vectors, dim = 1)
>>> b = b.repeat  (1, num_vectors, 1)
>>>
>>> bilinear_layer = torch.nn.Bilinear (dim, dim, 1)
>>> c = bilinear_layer (a, b) # |c| = (batch, num_vectors * num_vectors, 1)
>>> c = c.view (batch, num_vectors, num_vectors)
>>>
>>> cc = torch.einsum ('ijk, nkm, ilm -> ijl', aa, bilinear_layer.weight, bb) + bilinear_layer.bias
>>> cc.shape
torch.Size([2, 10, 10])
>>> torch.allclose (cc, c)
True
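One caveat on the einsum above: it sums over the weight's first index n, which is the output-feature index of nn.Bilinear (its weight has shape (out_features, in1_features, in2_features)). That is harmless here only because out_features = 1. Per output feature, Bilinear computes x1ᵀ A x2 + b, so with more outputs that index has to be kept. A minimal sketch of that generalization follows; pairwise_bilinear and pairwise_bilinear_repeat are hypothetical helper names, and the repeat version exists only as a reference to check against.

```python
import torch

def pairwise_bilinear(a, b, layer):
    """Hypothetical helper: apply nn.Bilinear to every (i, j) pair without repeating a or b.

    a: (batch, na, in1), b: (batch, nb, in2)
    returns: (batch, na, nb, out_features)
    """
    # Keep the output-feature index o instead of summing it away
    return torch.einsum('nik,okm,njm->nijo', a, layer.weight, b) + layer.bias

def pairwise_bilinear_repeat(a, b, layer):
    """Hypothetical reference: the memory-hungry repeat approach from the question."""
    batch, na, _ = a.shape
    nb = b.shape[1]
    a_rep = a.repeat_interleave(nb, dim=1)  # row i * nb + j holds a[:, i]
    b_rep = b.repeat(1, na, 1)              # row i * nb + j holds b[:, j]
    return layer(a_rep, b_rep).view(batch, na, nb, layer.out_features)

torch.manual_seed(0)
a = torch.rand(2, 10, 128)
b = torch.rand(2, 10, 128)

for out_features in (1, 4):
    layer = torch.nn.Bilinear(128, 128, out_features)
    c = pairwise_bilinear(a, b, layer)
    c_ref = pairwise_bilinear_repeat(a, b, layer)
    print(out_features, tuple(c.shape), torch.allclose(c, c_ref, atol=1e-4))
```

For out_features = 1 the result keeps a trailing axis of size 1, so c[..., 0] (or c.squeeze(-1)) recovers the (batch, num_vectors, num_vectors) shape from the question.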

Best.

K. Frank

Thanks, Frank! I incorporated your solution and it works well. Thank you so much.