Batch matrix-vector multiplication without a for loop

Hi, I want to do batch matrix-vector multiplication but cannot figure out how to do that.

For example, the input has shape (N x M x VectorSize) and the weight has shape (M x VectorSize x VectorSize), and I want an output of shape (N x M x VectorSize).

N is the batch size, M is the number of vectors, and VectorSize is the length of each vector. I want to multiply each vector by its corresponding matrix, as in the picture below: the m-th vector of every batch element is multiplied by the m-th weight matrix.
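Spelled out with explicit loops, the computation looks like the sketch below. It assumes the row-vector convention (vector times matrix, `x @ W`), which is the convention the answers below use; the sizes are small example values.

```python
import torch

N, M, V = 2, 3, 5
x = torch.randn(N, M, V)  # input: N batches of M vectors of size V
w = torch.randn(M, V, V)  # weight: one V x V matrix per vector position m

# Reference loop: the m-th vector of every batch element is
# multiplied by the m-th weight matrix.
out = torch.empty(N, M, V)
for n in range(N):
    for m in range(M):
        out[n, m] = x[n, m] @ w[m]  # (V,) @ (V, V) -> (V,)

print(out.shape)  # torch.Size([2, 3, 5])
```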

How can I do this without any for-loop?

Thank you.

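import torch

N, M, V = 2, 3, 5
a = torch.randn(N, M, V)            # input: (N, M, V)
b = torch.randn(M, V, V)            # weight: (M, V, V)
a_expand = a.unsqueeze(-2)          # (N, M, 1, V): each vector as a 1 x V row
b_expand = b.expand(N, -1, -1, -1)  # (N, M, V, V): weights repeated over the batch (a view, no copy)

c = torch.matmul(a_expand, b_expand).squeeze(-2)  # (N, M, 1, V) -> (N, M, V)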
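Two variations on the answer above, offered as a sketch: `torch.matmul` broadcasts batch dimensions, so the explicit `expand` can be dropped; and `torch.einsum` (a different technique from the one in the answer) expresses the same contraction in a single call. The loop at the end only checks that the results agree.

```python
import torch

N, M, V = 2, 3, 5
a = torch.randn(N, M, V)
b = torch.randn(M, V, V)

# matmul broadcasts the (M, V, V) weight against the leading N dimension,
# so expand() is not required: (N, M, 1, V) @ (M, V, V) -> (N, M, 1, V)
c_matmul = torch.matmul(a.unsqueeze(-2), b).squeeze(-2)

# einsum: out[n, m, j] = sum over i of a[n, m, i] * b[m, i, j]
c_einsum = torch.einsum('nmi,mij->nmj', a, b)

# Check both against an explicit loop
c_loop = torch.stack([torch.stack([a[n, m] @ b[m] for m in range(M)]) for n in range(N)])
print(torch.allclose(c_matmul, c_loop, atol=1e-6), torch.allclose(c_einsum, c_loop, atol=1e-6))
```

If the intended product is matrix times column vector (`W @ x`) rather than `x @ W`, the einsum subscripts become `'mij,nmj->nmi'`.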

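Try torch.bmm. Note that in this example the weight W has one matrix per batch element, shape (N x V x V), rather than one matrix per vector as in the question.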

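import torch

N, M, V = 2, 3, 4
X = torch.rand(N, M, V)
W = torch.rand(N, V, V)  # one V x V matrix per batch element

# loop implementation
res = torch.zeros(N, M, V)
for i in range(N):
    res[i, ...] = torch.matmul(X[i, ...], W[i, ...])  # (M, V) @ (V, V) -> (M, V)

print(res)

# no-loop implementation: bmm multiplies the i-th (M, V) matrix by the i-th (V, V) matrix
res2 = torch.bmm(X, W)
print(res2)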
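The example above gives every batch element its own weight matrix (W has shape N x V x V). For the shapes in the question, where the weight is M x V x V (one matrix per vector position), `torch.bmm` can still be used by temporarily making M the batch dimension. A minimal sketch, assuming the same `x @ W` convention:

```python
import torch

N, M, V = 2, 3, 4
X = torch.rand(N, M, V)
W = torch.rand(M, V, V)  # one matrix per vector position, as in the question

# Move M to the front so it becomes bmm's batch dimension:
# (M, N, V) @ (M, V, V) -> (M, N, V), then swap back to (N, M, V)
res = torch.bmm(X.transpose(0, 1), W).transpose(0, 1)
print(res.shape)  # torch.Size([2, 3, 4])

# Compare with the einsum formulation of the same product
print(torch.allclose(res, torch.einsum('nmi,mij->nmj', X, W), atol=1e-6))
```

The result is a transposed view of the bmm output; call `.contiguous()` on it if later code needs contiguous memory.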

Thank you!! It helped me a lot :smiley: