Weighted Sum of Matrices

I have two tensors, A and B. A has shape (X, 500, 8) and B has shape (X, 8, 1), so A is basically X records of shape (500, 8), and B holds the weights for each x in X. I would like to calculate the weighted sum for each x, so that I end up with a vector of shape (500, 1) for each x. The final result should have shape (X, 500).
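Written out per element, the computation described above is roughly the following (assuming 1-based indices, and that every one of the 500 rows of A[x] is weighted by the same 8 weights in B[x]):

$$
\text{result}[x, i] = \sum_{k=1}^{8} A[x, i, k]\, B[x, k, 1], \qquad i = 1, \dots, 500
$$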

I would like to think of it as a batch of matrix products of shapes (500x8) and (8x1). Any idea how to do that in PyTorch?


A batched matrix multiplication (torch.bmm) could be what you mean: each matrix in A gets multiplied with its corresponding matrix in B.
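A minimal sketch of the batched approach, assuming random data and an arbitrary batch size of X = 4 (both hypothetical, just for illustration). torch.bmm multiplies A[x] @ B[x] for every x, and squeezing the trailing singleton dimension gives the requested (X, 500) shape. A per-record loop is included only as a sanity check.

```python
import torch

# Hypothetical sizes and data: X = 4 is arbitrary, values are random.
X = 4
A = torch.randn(X, 500, 8)  # X records, each of shape (500, 8)
B = torch.randn(X, 8, 1)    # one weight vector of shape (8, 1) per record

# Batched matrix multiply: A[x] @ B[x] for every x.
# (X, 500, 8) @ (X, 8, 1) -> (X, 500, 1)
out = torch.bmm(A, B)

# Drop the trailing singleton dimension to get shape (X, 500).
result = out.squeeze(-1)

# Reference: explicit loop over records, for comparison only.
expected = torch.stack([A[x] @ B[x] for x in range(X)]).squeeze(-1)

print(result.shape)  # torch.Size([4, 500])
print(torch.allclose(result, expected, atol=1e-6))
```

A @ B (i.e. torch.matmul) would give the same result here, since it treats the leading dimension as a batch dimension; torch.bmm is stricter and requires both inputs to be 3-D with matching batch sizes.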

Exactly what I wanted, thank you!