# Efficient batch dot product

I am given a batch of row vectors stored in the matrix `U`, a batch of column vectors stored in the matrix `V`, and a single matrix `M`. For each row vector `u` in `U` and the corresponding column vector `v` in `V` I want to compute the product `u*M*v`, and then sum these scalars over the batch.

How can I efficiently implement this (potentially using `bmm()`, `matmul()` or maybe even `einsum`)?

Here is a small toy example doing what I want to do with a for loop:

```python
import torch

U = torch.arange(1, 10).reshape(3, 3)     # batch of row vectors, one per row
V = torch.arange(1, 10).reshape(3, 3)     # batch of column vectors, one per row
M = torch.tensor([1, 2, 3]).repeat(3, 1)

result = 0
for u, v in zip(U, V):
    result += torch.matmul(torch.matmul(u, M), v)
```
```
result:
tensor(1764)
```

I know there is `torch.bmm()` to perform batched matrix-matrix multiplication. If there were something similar for a batched vector dot product (e.g. `torch.bvv()`) I could do `bvv(matmul(U, M), V)`.
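Such a `bvv()` could presumably be emulated with `bmm()` itself by viewing each vector as a 1×N or N×1 matrix; a rough sketch of that idea (the name `bvv` is just the hypothetical one from above):

```python
import torch

def bvv(a, b):
    # Batched dot product: a and b have shape (B, N); returns shape (B,).
    # Each pair becomes a (1, N) @ (N, 1) matrix product via bmm.
    return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).view(-1)

U = torch.arange(1, 10).reshape(3, 3)
V = torch.arange(1, 10).reshape(3, 3)
M = torch.tensor([1, 2, 3]).repeat(3, 1)

print(bvv(U @ M, V))        # tensor([  84,  480, 1200])
print(bvv(U @ M, V).sum())  # tensor(1764)
```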

Isn’t it just `bmm()` applied twice?

```python
import torch

U = torch.arange(1, 10).reshape(1, 3, 3)
V = torch.arange(1, 10).reshape(1, 3, 3)
M = torch.tensor([1, 2, 3]).repeat(3, 1).view(1, 3, 3)
result = torch.bmm(U, M).bmm(V)
```
```
result:
tensor([[[ 180,  216,  252],
         [ 450,  540,  630],
         [ 720,  864, 1008]]])
```

I don’t think that is correct. Note that the `result` should be a scalar.

```python
((U @ M) * V).sum()
```

(EDIT: The `@` operator is equivalent to `torch.matmul`)


In that case, note what `result` contains: element `(i, j)` of the `result` matrix is `u_i @ M @ c_j`, where `c_j` is the j-th *column* of `V`, so it holds every combination rather than just the matched pairs, and calling `.sum()` on it also adds all the cross terms. To get only the matched pairs, transpose `V` first and sum the diagonal: `torch.bmm(U, M).bmm(V.transpose(1, 2)).diagonal(dim1=-2, dim2=-1).sum()`.
Maybe you can work out the math to see why that matches the loop result above.
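
Checking this on the toy data from the first post (a quick sketch; the value matches the corrected loop's `tensor(1764)`):

```python
import torch

U = torch.arange(1, 10).reshape(1, 3, 3)
V = torch.arange(1, 10).reshape(1, 3, 3)
M = torch.tensor([1, 2, 3]).repeat(3, 1).view(1, 3, 3)

# All pairwise combinations: entry (i, j) holds u_i @ M @ v_j.
pairs = torch.bmm(U, M).bmm(V.transpose(1, 2))

# Keep only the matched pairs (the diagonal), then reduce to a scalar.
print(pairs.diagonal(dim1=-2, dim2=-1).sum())  # tensor(1764)
```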

`torch.einsum('bi,ij,bj', U, M, V)` if you want the sum, `'bi,ij,bj->b'` if you prefer the batch items separately.

Best regards

Thomas
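
For reference, a quick sketch checking both forms on the 2-D toy data from the first post:

```python
import torch

U = torch.arange(1, 10).reshape(3, 3)
V = torch.arange(1, 10).reshape(3, 3)
M = torch.tensor([1, 2, 3]).repeat(3, 1)

print(torch.einsum('bi,ij,bj', U, M, V))     # tensor(1764)
print(torch.einsum('bi,ij,bj->b', U, M, V))  # tensor([  84,  480, 1200])
```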


Perfect, thanks! Einsum is really neat; I took the time to get familiar with it and came up with the same result.

I faced a similar problem and noticed that a faster way of computing `torch.einsum('bi,ij,bj->b', U, M, V)` is `torch.sum(U @ M * V, dim=1)`.
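
A minimal sketch checking that equivalence on random float data (the sizes are made up for illustration; actual speedups will depend on shapes and hardware):

```python
import torch

B, M_dim, N = 64, 128, 256  # hypothetical sizes, just for illustration
U = torch.randn(B, M_dim)
M = torch.randn(M_dim, N)
V = torch.randn(B, N)

a = torch.einsum('bi,ij,bj->b', U, M, V)
b = torch.sum(U @ M * V, dim=1)   # (U @ M) is (B, N); elementwise * V, sum over N
print(torch.allclose(a, b, atol=1e-4))  # True (up to float rounding)
```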

For this particular case, the specialized `bilinear` function (`torch.nn.functional.bilinear`) might be a good thing to use.


Thanks for pointing out the `bilinear` function; it’s very useful and I wasn’t aware of it.

Though, in my case, I had to do a bit of squeezing/unsqueezing to get it to work. The dimensions are exactly as in my einsum example, namely (B, M), (B, N), and (M, N) for `U`, `V`, and `W`, respectively. Calling `bilinear(U, V, W)` in this case requires `W` to have shape (Y, M, N), where Y is the number of out-features, so I had to call it like this: `bilinear(U, V, W.unsqueeze(0)).squeeze()` to make it equivalent to `torch.sum(U @ W * V, dim=1)`.
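
Concretely, a small sketch of what I mean, assuming float tensors with the shapes above:

```python
import torch
import torch.nn.functional as F

B, M, N = 4, 3, 5  # hypothetical sizes, just for illustration
U = torch.randn(B, M)
V = torch.randn(B, N)
W = torch.randn(M, N)

# bilinear expects a weight of shape (out_features, in1_features, in2_features),
# hence the unsqueeze to (1, M, N) and the squeeze of the out_features dim.
a = F.bilinear(U, V, W.unsqueeze(0)).squeeze(-1)
b = torch.sum(U @ W * V, dim=1)
print(torch.allclose(a, b, atol=1e-6))  # True
```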

In addition to that, `torch.sum(U @ W * V, dim=1)` seems to be about 1.5 times faster than `bilinear(U, V, W.unsqueeze(0)).squeeze()` in my case, though I’m not sure why.

For the batched vector dot product itself (the `bvv()` wished for above), there is also simply:

```python
(U * V).sum(axis=-1)
```
