Efficient batch dot product

I am given a batch of row vectors stored in the matrix U, a batch of column vectors stored in the matrix V and a single matrix M. For each row vector u in U and the corresponding column vector v in V, I want to compute the product u M v, and then sum these values over the batch.

How can I efficiently implement this (potentially using bmm(), matmul() or maybe even einsum)?

Here is a small toy example that does what I want with a for loop:

import torch

U = torch.arange(1,10).reshape(3,3)
V = torch.arange(1,10).reshape(3,3)
M = torch.tensor([1, 2, 3]).repeat(3,1)

result = 0
for u, v in zip(U, V):
    # u is a row of U, v the matching row of V; u @ M @ v is a scalar
    result += torch.matmul(torch.matmul(u, M), v)
print(result)
# tensor(1764)

I know there is torch.bmm() for batch matrix-matrix multiplication. If there were something similar for a batched vector dot product (e.g. torch.bvv()), I could do bvv(matmul(U, M), V).

Isn’t it just two bmm() calls?

import torch
U = torch.arange(1,10).reshape(1, 3,3) 
V = torch.arange(1,10).reshape(1, 3,3) 
M = torch.tensor([1, 2, 3]).repeat(3,1).view(1,3,3) 
result = torch.bmm(U, M).bmm(V)
result: 
tensor([[[ 180,  216,  252],
         [ 450,  540,  630],
         [ 720,  864, 1008]]])

I don’t think that is correct. Note that the result should be a scalar.

((U @ M) * V).sum()

(EDIT: The @ operator is equivalent to torch.matmul)
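A minimal check, assuming the pairing described in the question (the b-th row of U goes with the b-th row of V): an explicit Python loop and the vectorized expression ((U @ M) * V).sum() should agree on the toy data.

```python
import torch

# Toy data from the question: rows of U and rows of V are paired by index b.
U = torch.arange(1, 10).reshape(3, 3)
V = torch.arange(1, 10).reshape(3, 3)
M = torch.tensor([1, 2, 3]).repeat(3, 1)

# Explicit loop: sum over b of U[b] @ M @ V[b].
loop_result = sum(u @ M @ v for u, v in zip(U, V))

# Vectorized: (U @ M) has shape (B, N); multiplying elementwise by V and
# summing everything gives the same scalar without a Python loop.
vectorized_result = ((U @ M) * V).sum()

print(loop_result, vectorized_result)  # tensor(1764) tensor(1764)
```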


In that case, you can call .sum() on result to get a scalar. Basically, every element of the result matrix contains the product of some combination of u, M and v.
Maybe you can work out the math and compare it with result.sum() to see if it is correct.
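A sketch of the math on the question's toy data, using plain 2-D tensors instead of a batch dimension of 1. Element [i, j] of U @ M @ V is U[i] @ M @ V[:, j], so it pairs rows of U with columns of V, and summing every element mixes in all cross combinations. Assuming row b of U should pair with row b of V, the wanted terms sit on the diagonal of U @ M @ V.T. This builds a full B x B matrix, so it only illustrates the math and is not an efficient method.

```python
import torch

U = torch.arange(1, 10).reshape(3, 3)
V = torch.arange(1, 10).reshape(3, 3)
M = torch.tensor([1, 2, 3]).repeat(3, 1)

# Same product as the bmm() answer, without the leading batch dimension.
full = U @ M @ V
print(full.sum().item())  # 4860: includes cross terms, not just matched pairs

# Pairing row b of U with row b of V: the wanted terms are on the diagonal.
pairwise = U @ M @ V.T            # pairwise[i, j] = U[i] @ M @ V[j]
per_item = torch.diagonal(pairwise)
print(per_item.tolist())          # [84, 480, 1200]
print(per_item.sum().item())      # 1764
```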

Use torch.einsum('bi,ij,bj', U, M, V) if you want the sum, or 'bi,ij,bj->b' if you prefer to keep the batch items separate.

Best regards

Thomas
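A short usage sketch of the einsum suggestion on the toy data from the question. Without '->', einsum's output keeps only the indices that appear exactly once; here b, i and j each appear twice, so everything is summed into a scalar. Writing '->b' keeps one value per batch item.

```python
import torch

U = torch.arange(1, 10).reshape(3, 3)     # (B, M): one row vector per batch item
V = torch.arange(1, 10).reshape(3, 3)     # (B, N): one vector per batch item
M = torch.tensor([1, 2, 3]).repeat(3, 1)  # (M, N)

# Implicit output: no index appears exactly once, so all indices are summed.
total = torch.einsum('bi,ij,bj', U, M, V)

# Explicit output '->b': keep the batch index.
per_item = torch.einsum('bi,ij,bj->b', U, M, V)

print(total.item())       # 1764
print(per_item.tolist())  # [84, 480, 1200]
```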


Perfect, thanks! Einsum is really neat, I took the time to get familiar with it and came up with the same result.

I faced a similar problem and noticed that a faster way to compute torch.einsum('bi,ij,bj->b', U, M, V) is torch.sum(U @ M * V, dim=1).
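A quick equivalence check, assuming float inputs with the shapes (B, M), (M, N) and (B, N); the sizes are hypothetical. In Python, @ and * have the same precedence and group left to right, so U @ M * V means (U @ M) * V: one matrix product, then an elementwise multiply and a row-wise sum.

```python
import torch

torch.manual_seed(0)
B, m, n = 4, 5, 6            # hypothetical sizes for illustration
U = torch.randn(B, m)
M = torch.randn(m, n)
V = torch.randn(B, n)

# (U @ M) is (B, n); multiply by V elementwise and sum each row.
fast = torch.sum(U @ M * V, dim=1)
reference = torch.einsum('bi,ij,bj->b', U, M, V)

print(fast.shape)                                   # torch.Size([4])
print(torch.allclose(fast, reference, atol=1e-6))   # True
```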

For this particular case, the specialized torch.nn.functional.bilinear function might be a good fit.


Thanks for pointing out the bilinear function, it’s very useful and I wasn’t aware of it.

Though, in my case, I had to do a bit of squeezing/unsqueezing to get it to work. The dimensions are exactly as in my einsum example, namely (B, M), (B, N) and (M, N) for U, V and W, respectively. Calling bilinear(U, V, W) in this case requires W to have shape (Y, M, N), where Y is the number of output features, so I had to call it as bilinear(U, V, W.unsqueeze(0)).squeeze() to make it equivalent to torch.sum(U @ W * V, dim=1).
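A sketch of that call with torch.nn.functional.bilinear, whose weight has shape (out_features, in1_features, in2_features); with one output feature the result has shape (B, 1). The sizes are hypothetical, and squeeze(-1) is used instead of a bare squeeze() so that a batch of size 1 keeps its batch dimension, a small deviation from the call in the post.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, m, n = 4, 5, 6            # hypothetical sizes for illustration
U = torch.randn(B, m)        # (B, M)
V = torch.randn(B, n)        # (B, N)
W = torch.randn(m, n)        # (M, N)

# bilinear computes U[b] @ weight[k] @ V[b] for each output feature k.
# unsqueeze(0) turns W into a single-output weight of shape (1, M, N).
out = F.bilinear(U, V, W.unsqueeze(0))   # shape (B, 1)
via_bilinear = out.squeeze(-1)           # shape (B,)

via_sum = torch.sum(U @ W * V, dim=1)
print(out.shape)                                         # torch.Size([4, 1])
print(torch.allclose(via_bilinear, via_sum, atol=1e-6))  # True
```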

In addition, torch.sum(U @ W * V, dim=1) seems to be about 1.5 times faster than bilinear(U, V, W.unsqueeze(0)).squeeze() in my case, though I’m not sure why.
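To reproduce such a comparison on one's own hardware, torch.utils.benchmark.Timer is one option. This is a sketch only: the sizes are hypothetical, and no particular ratio should be expected, since timings depend on device, dtype and shapes.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

torch.manual_seed(0)
B, m, n = 1024, 256, 256     # hypothetical sizes; adjust to your workload
U, V, W = torch.randn(B, m), torch.randn(B, n), torch.randn(m, n)

stmts = {
    "sum": "torch.sum(U @ W * V, dim=1)",
    "bilinear": "F.bilinear(U, V, W.unsqueeze(0)).squeeze(-1)",
}

timings = {}
for name, stmt in stmts.items():
    timer = benchmark.Timer(
        stmt=stmt,
        globals={"torch": torch, "F": F, "U": U, "V": V, "W": W},
    )
    # timeit(n) runs the statement n times and returns a Measurement;
    # .mean is the average time per run in seconds.
    timings[name] = timer.timeit(50).mean
    print(f"{name}: {timings[name] * 1e6:.1f} us per call")
```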

(U * V).sum(axis=-1)
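The torch.bvv() wished for in the question is not a PyTorch function, but this one-liner plays that role: an elementwise product followed by a sum over the last dimension. A sketch wrapping it in a hypothetical helper named bvv and applying it as the question proposed, bvv(matmul(U, M), V), on the toy data:

```python
import torch

def bvv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: row-wise dot products of two (B, N) tensors."""
    return (a * b).sum(dim=-1)

U = torch.arange(1, 10).reshape(3, 3)
V = torch.arange(1, 10).reshape(3, 3)
M = torch.tensor([1, 2, 3]).repeat(3, 1)

per_item = bvv(torch.matmul(U, M), V)   # one value per batch item
print(per_item.tolist())                # [84, 480, 1200]
print(per_item.sum().item())            # 1764
```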
