Hi, I want to do batch matrix-vector multiplication but cannot figure out how to do it.
For example, the input has shape (N x M x VectorSize), the weight has shape (M x VectorSize x VectorSize), and I want an output of shape (N x M x VectorSize).
N is the batch size, M is the number of vectors, and VectorSize is the size of each vector. I want to multiply each vector by its corresponding matrix (vector m by matrix m, for every batch element), as in the picture below.
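A slow but explicit reference loop makes the intended computation unambiguous and gives you something to check faster versions against. This is a minimal sketch. It assumes each vector is a row vector multiplied on the left of its matrix (`x[n, m] @ w[m]`); if you want matrix times column vector instead, swap the operands. The helper name `batch_vec_mat_loop` is hypothetical.

```python
import torch

def batch_vec_mat_loop(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper (assumes row-vector @ matrix).

    x: (N, M, V) batch of M vectors each.
    w: (M, V, V) one matrix per vector index m, shared across the batch.
    Returns (N, M, V) with out[n, m] = x[n, m] @ w[m].
    """
    n_batch, m_vecs, v = x.shape
    out = torch.empty(n_batch, m_vecs, v, dtype=x.dtype)
    for n in range(n_batch):
        for m in range(m_vecs):
            # (V,) @ (V, V) -> (V,)
            out[n, m] = x[n, m] @ w[m]
    return out

x = torch.randn(2, 3, 5)
w = torch.randn(3, 5, 5)
out = batch_vec_mat_loop(x, w)
print(out.shape)  # torch.Size([2, 3, 5])
```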
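```python
import torch

N, M, V = 2, 3, 5
a = torch.randn(N, M, V)  # N batches of M vectors of size V
b = torch.randn(M, V, V)  # one V x V matrix per vector index m
a_expand = a.unsqueeze(-2)          # (N, M, 1, V): each vector as a 1 x V row
b_expand = b.expand(N, -1, -1, -1)  # (N, M, V, V): reuse the M matrices for every batch element
c = torch.matmul(a_expand, b_expand).squeeze(-2)  # (N, M, 1, V) -> (N, M, V)
```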
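Because `torch.matmul` broadcasts leading batch dimensions, the explicit `expand` in the snippet above should not be needed, and `torch.einsum` can express the same contraction directly. A sketch, assuming the same row-vector convention, that checks all three forms agree:

```python
import torch

N, M, V = 2, 3, 5
a = torch.randn(N, M, V)
b = torch.randn(M, V, V)

# Approach from the snippet above: explicit expand to (N, M, V, V)
c_expand = torch.matmul(a.unsqueeze(-2), b.expand(N, -1, -1, -1)).squeeze(-2)

# matmul broadcasts (N, M, 1, V) against (M, V, V), so expand can be dropped
c_broadcast = torch.matmul(a.unsqueeze(-2), b).squeeze(-2)

# einsum: c[n, m, j] = sum_k a[n, m, k] * b[m, k, j]
c_einsum = torch.einsum('nmk,mkj->nmj', a, b)

print(torch.allclose(c_expand, c_broadcast, atol=1e-6),
      torch.allclose(c_expand, c_einsum, atol=1e-6))
```

The broadcasting version avoids spelling out N, and the einsum string documents which index is summed over, which some readers find easier to audit.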
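```python
import torch

N, M, V = 2, 3, 4
X = torch.rand(N, M, V)
# Note: W holds one V x V matrix per batch element (N x V x V),
# not one per vector (M x V x V) as in the question above.
W = torch.rand(N, V, V)

# loop implementation
res = torch.zeros(N, M, V)
for i in range(N):
    res[i, ...] = torch.matmul(X[i, ...], W[i, ...])  # (M, V) @ (V, V) -> (M, V)
print(res)

# no loop implementation
res2 = torch.bmm(X, W)  # batched (M, V) @ (V, V) for each of the N batch elements
print(res2)
```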
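The `torch.bmm` snippet above uses one matrix per batch element. If the weights are instead indexed by vector, as in the question (M x V x V), one possible sketch is to move M into the leading position so `bmm` performs M independent (N x V) @ (V x V) products, then move it back. The variable names here are illustrative only.

```python
import torch

N, M, V = 2, 3, 4
X = torch.rand(N, M, V)  # (N, M, V) input
W = torch.rand(M, V, V)  # (M, V, V): one matrix per vector index, as in the question

# (N, M, V) -> (M, N, V) so that M becomes bmm's batch dimension
X_t = X.transpose(0, 1)
# (M, N, V) bmm (M, V, V) -> (M, N, V); row n of batch m is X[n, m] @ W[m]
out_t = torch.bmm(X_t, W)
# back to (N, M, V)
out = out_t.transpose(0, 1)

# Cross-check against an explicit loop over the vector index m
ref = torch.stack([X[:, m, :] @ W[m] for m in range(M)], dim=1)
print(torch.allclose(out, ref, atol=1e-6))
```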