How to compute batched matrix multiplication of 3-D tensors efficiently

Regular matrix multiplication:

  • Suppose I have N1 samples and N2 samples, all of dimension D:
    X1 = [N1,D], X2 = [N2,D]
    To calculate the similarity matrix between the samples, I can use S = X2.mm(X1.T), where S = [N2,N1].

  • But if X1 = [B,N1,D] and X2 = [B,N2,D], where B denotes the batch size,
    I want to compute the matrix multiplication batch-wise.
    Below is a for-loop version; how can I do it efficiently?
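The non-batched similarity computation from the first bullet can be sketched as follows. The sizes are hypothetical, chosen only for illustration; the point is that S[j, i] is the dot product of sample j of X2 with sample i of X1, which is why S has shape [N2,N1].

```python
import torch

# Hypothetical sizes chosen for illustration only.
N1, N2, D = 4, 3, 5
X1 = torch.randn(N1, D)  # N1 samples of dimension D
X2 = torch.randn(N2, D)  # N2 samples of dimension D

# S[j, i] = <X2[j], X1[i]>, so S has shape [N2, N1].
S = X2.mm(X1.T)
print(S.shape)  # torch.Size([3, 4])

# Spot-check one entry against an explicit dot product.
assert torch.allclose(S[1, 2], X2[1].dot(X1[2]))
```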

import torch

def MatrixMulti(A,B):
    """
    The first dimension is an independent batch dimension;
    matrix multiplication is computed over the remaining dimensions.
    inputs:
        A : [b,N1,D]
        B : [b,N2,D]
    output:
        S : [b,N2,N1]
    """

    b,N1,D = A.size()
    b,N2,D = B.size()
    # match A's dtype and device so the loop also works on GPU tensors
    S = torch.zeros((b,N2,N1), dtype=A.dtype, device=A.device)
    for i in range(b):
        # [N2,D] x [D,N1] -> [N2,N1]
        S[i,:,:] = B[i,:,:].mm(A[i,:,:].transpose(0,1))
    return S

Hi Fly!

Use torch.bmm() (“batch matrix-matrix”), together with torch.transpose()
to line up the D dimensions properly. To get the same [B,N2,N1] layout as your loop:

S = torch.bmm(X2, torch.transpose(X1, 1, 2))
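A minimal sketch for checking the batched version against the loop. Assumptions: `matrix_multi_loop` is a hypothetical rename of the question's MatrixMulti, and the sizes are made up for illustration. It compares torch.bmm with two equivalent spellings, the `@` operator (torch.matmul) and torch.einsum, and shows that the operand order decides whether you get [B,N2,N1] or its transpose [B,N1,N2].

```python
import torch

def matrix_multi_loop(A, B):
    """Reference loop (hypothetical rename of MatrixMulti): S[i] = B[i] @ A[i]^T."""
    b, N1, _ = A.shape
    _, N2, _ = B.shape
    S = torch.zeros((b, N2, N1), dtype=A.dtype, device=A.device)
    for i in range(b):
        S[i] = B[i].mm(A[i].T)
    return S

# Hypothetical sizes for illustration.
Bsz, N1, N2, D = 2, 4, 3, 5
X1 = torch.randn(Bsz, N1, D)
X2 = torch.randn(Bsz, N2, D)

ref = matrix_multi_loop(X1, X2)  # [B, N2, N1]

# Batched product in the loop's [B, N2, N1] orientation.
S_bmm = torch.bmm(X2, X1.transpose(1, 2))
# Equivalent spellings: @ / torch.matmul treats the leading dim as a batch dim,
# and einsum names the contraction over D explicitly.
S_matmul = X2 @ X1.transpose(-2, -1)
S_einsum = torch.einsum('bnd,bmd->bnm', X2, X1)

for S in (S_bmm, S_matmul, S_einsum):
    assert S.shape == (Bsz, N2, N1)
    assert torch.allclose(S, ref, atol=1e-6)

# Swapping the operands yields the transpose, shape [B, N1, N2].
S_swapped = torch.bmm(X1, X2.transpose(1, 2))
assert torch.allclose(S_swapped, ref.transpose(1, 2), atol=1e-6)
```

torch.bmm only accepts 3-D tensors with the same batch size, whereas `@` / torch.matmul also broadcasts batch dimensions, so it carries over to inputs with more than one leading batch dimension.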

Best.

K. Frank


That is brilliant, thank you!