I am given a batch of row vectors stored in the matrix U, a batch of column vectors stored in the matrix V, and a single matrix M. For each row vector u in U and the corresponding column vector v in V, I want to compute the matrix product u*M*v, summed over the batch.
How can I implement this efficiently (potentially using bmm(), matmul(), or maybe even einsum())?
Here is a small toy example doing what I want to do with a for loop:
import torch
U = torch.arange(1,10).reshape(3,3)
V = torch.arange(1,10).reshape(3,3)
M = torch.tensor([1, 2, 3]).repeat(3,1)
result = 0
for u, v in zip(U, V.t()):
    result += torch.matmul(torch.matmul(u, M), v)
result:
tensor(1728)
I know there is torch.bmm() for batched matrix-matrix multiplication. If there were something similar for a batched vector-vector dot product (e.g. torch.bvv()), I could do bvv(matmul(U,M), V).
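For what it's worth, the per-batch sum of u·M·v can also be written as a single einsum call. A sketch, assuming the u are the rows of U and the v are the columns of V, as described above:

```python
import torch

U = torch.arange(1, 10).reshape(3, 3)   # rows are the u vectors
V = torch.arange(1, 10).reshape(3, 3)   # columns are the v vectors
M = torch.tensor([1, 2, 3]).repeat(3, 1)

# sum over b of  u_b @ M @ v_b,  with u_b = U[b] and v_b = V[:, b]
result = torch.einsum('bi,ij,jb->', U, M, V)
print(result)  # tensor(1728)
```

The subscript string contracts i and j (the matrix product) and also sums away the batch index b, so no intermediate per-batch tensor needs to be materialized explicitly.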
import torch
U = torch.arange(1,10).reshape(1, 3,3)
V = torch.arange(1,10).reshape(1, 3,3)
M = torch.tensor([1, 2, 3]).repeat(3,1).view(1,3,3)
result = torch.bmm(U, M).bmm(V)
In that case, you can call .sum() on result to get a scalar: every element of the result matrix contains u_i @ M @ v_j for one combination of u and v. You can work out the math and check result.sum() to see whether it matches what you expect.
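Note that result.sum() adds up all u/v combinations. If you only want the matched pairs (u_b with v_b for the same batch index b), you can take the diagonal instead. A sketch using the toy tensors above:

```python
import torch

U = torch.arange(1, 10).reshape(1, 3, 3)
V = torch.arange(1, 10).reshape(1, 3, 3)
M = torch.tensor([1, 2, 3]).repeat(3, 1).view(1, 3, 3)

result = torch.bmm(U, M).bmm(V)

# sums u_i @ M @ v_j over ALL (i, j) combinations
print(result.sum())  # tensor(4860)

# the diagonal holds the matched pairs u_b @ M @ v_b (v_b = column b of V)
print(torch.diagonal(result, dim1=-2, dim2=-1).sum())  # tensor(1728)
```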
Thanks for pointing out the bilinear function, it’s very useful and I wasn’t aware of it.
Though, in my case, I had to do a bit of squeeze/unsqueeze-ing to get it to work. The dimensions are exactly as in my einsum example, namely (B,M), (B,N), and (M,N) for U, V, and W, respectively. Calling bilinear(U,V,W) in this case requires W to have shape (Y,M,N), where Y is the number of out-features, so I had to call it like this: bilinear(U,V,W.unsqueeze(0)).squeeze(), to make it equivalent to torch.sum(U @ W * V, dim=1).
In addition to that, torch.sum(U @ W * V, dim=1) seems to be about 1.5 times faster than bilinear(U,V,W.unsqueeze(0)).squeeze() in my case, though I'm not sure why.
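For anyone comparing the two formulations, here is a minimal sketch (with made-up sizes B, M, N) checking that they agree numerically:

```python
import torch
import torch.nn.functional as F

B, M, N = 4, 3, 5
U = torch.randn(B, M)
V = torch.randn(B, N)
W = torch.randn(M, N)

a = torch.sum(U @ W * V, dim=1)                  # shape (B,)
# F.bilinear expects a weight of shape (out_features, M, N), hence the unsqueeze
b = F.bilinear(U, V, W.unsqueeze(0)).squeeze(1)  # shape (B,)

print(torch.allclose(a, b, atol=1e-5))  # True
```

The small tolerance accounts for the two implementations possibly accumulating the floating-point sums in a different order.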