Regular matrix multiplication:
Suppose I have N1 samples and N2 samples, each of dimension D:
X1 = [N1,D], X2 = [N2,D]
To calculate the similarity matrix between the samples, I can use S = X2.mm(X1.T),
where S = [N2,N1].
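For concreteness, here is a minimal sketch of the 2-D case described above. The sizes (N1 = 5, N2 = 7, D = 3) and the random inputs are illustrative assumptions, not values from the original setup.

```python
import torch

# Illustrative sizes only (assumed): N1 = 5, N2 = 7, D = 3
X1 = torch.randn(5, 3)  # [N1, D]
X2 = torch.randn(7, 3)  # [N2, D]

# [N2, D] @ [D, N1] -> [N2, N1]
S = X2.mm(X1.T)
print(S.shape)
```

Each entry S[j, i] is the dot product between sample j of X2 and sample i of X1.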
But if X1 = [B,N1,D] and X2 = [B,N2,D], where B denotes the batch size,
and I want to compute the matrix multiplication batch-wise,
how can I do it efficiently?
Below is a for-loop version:
import torch

def MatrixMulti(A, B):
    """
    The first dimension is independent
    and computes matrix multiplication
    in the remaining dimensions
    inputs:
        A : [b,N1,D]
        B : [b,N2,D]
    output:
        S : [b,N2,N1]
    """
    b, N1, D = A.size()
    _, N2, _ = B.size()
    # Allocate the output with the same dtype and device as the inputs
    S = torch.zeros((b, N2, N1), dtype=A.dtype, device=A.device)
    for i in range(b):
        # [N2,D] @ [D,N1] -> [N2,N1]
        S[i, :, :] = B[i, :, :].mm(A[i, :, :].transpose(0, 1))
    return S
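One possible way to avoid the Python loop is a batched matrix product. The sketch below uses `torch.bmm` with the last two dimensions of A swapped; the function name `batched_similarity` and the tensor sizes are hypothetical, chosen only for illustration, and the result is checked against an explicit per-batch loop.

```python
import torch

def batched_similarity(A, B):
    """A: [b, N1, D], B: [b, N2, D] -> S: [b, N2, N1]."""
    # Swap the last two dims of A so each batch entry is [D, N1],
    # then multiply batch-wise: [b, N2, D] @ [b, D, N1] -> [b, N2, N1].
    return torch.bmm(B, A.transpose(1, 2))

# Illustrative sizes only (assumed): b = 4, N1 = 5, N2 = 7, D = 3
torch.manual_seed(0)
A = torch.randn(4, 5, 3)
B = torch.randn(4, 7, 3)
S = batched_similarity(A, B)

# Reference result from an explicit per-batch loop
S_loop = torch.stack([B[i].mm(A[i].t()) for i in range(A.size(0))])
print(S.shape)
print(torch.allclose(S, S_loop, atol=1e-5))
```

`torch.matmul(B, A.transpose(-1, -2))` and `torch.einsum('bnd,bmd->bnm', B, A)` should give the same result; which one is fastest may depend on the shapes and the device, so it is worth timing them on the real data.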