How to use einsum properly? It seems to run faster in some settings

Hi! I recently started using einsum, and I have noticed that one of the following two forms of the sum-product is processed noticeably faster than the other.

My matrix multiplication is meant to achieve the following:

  1. I want to calculate the matrix multiplication B = A^T * A, where A^T is NxHxD and A is DxHxN. I would like the result B in the form NxNxH.

  2. Then, for each upper-triangular coordinate of the NxN plane of B, I select the corresponding feature vector of size 1xH.

  • The question is: why is the 1st approach below faster once N grows beyond 100, even though the 2nd approach needs fewer FLOPs? Did I misunderstand anything? (The tensor setup shared by both snippets is sketched right below.)

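For context, the tensors used in both snippets are set up roughly like this (a minimal sketch; H and D here are placeholder values, not the ones from my actual run):

import torch

N, H, D = 2000, 8, 64                          # H and D are placeholders
A_ = torch.randn(D, H, N, device='cuda')       # A_ = [DxHxN]

# upper-triangular (j, i) pairs of the NxN plane, so #Edges ≈ N^2/2
edge_index = torch.triu_indices(N, N, offset=1, device='cuda')   # edge_index = [2 x #Edges]
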
The 1st approach, which is quite straightforward and turns out to be the faster one, is as follows:

A_      = [DxHxN]

B1 =  torch.einsum('dhn,dhm->hnm', A_, A_)     # B1  = [HxNxN]

edge_index_j, edge_index_i = edge_index  # edge_index = [2 x #Edges]
# where #Edges = N^2/2

B1 = B1[:,edge_index_j,edge_index_i]     # B1  = [Hx#Edges]
B1 = B1.permute(1, 0)                    # B1  = [#EdgesxH]  

The 2nd, slower approach is:

A   = A.permute(2,1,0)                         # A   = [NxHxD]
A_i = torch.index_select(A, 0, edge_index_i)   # A_i = [#EdgesxHxD]
A_j = torch.index_select(A, 0, edge_index_j)   # A_j = [#EdgesxHxD]

B2 = torch.einsum("ehd,ehd->eh", A_i, A_j)     # B2  = [#EdgesxH]

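As a quick sanity check (assuming A and A_ above refer to the same underlying data), the two results should agree up to numerical tolerance:

print(torch.allclose(B1, B2, atol=1e-5))       # both are [#EdgesxH]; expected: True
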
Once N >= 100, I measure a lower runtime for the 1st approach than for the 2nd (here I set N = 2000).

Runtime (1st): 1.312180 ms
Runtime (2nd): 32.030568 ms

Here, I measured the runtime with torch.cuda.Event.
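The measurement pattern is roughly the following (a simplified sketch, shown for the 1st approach):

start = torch.cuda.Event(enable_timing=True)
end   = torch.cuda.Event(enable_timing=True)

start.record()
B1 = torch.einsum('dhn,dhm->hnm', A_, A_)
B1 = B1[:, edge_index_j, edge_index_i].permute(1, 0)
end.record()

torch.cuda.synchronize()                       # wait for the GPU kernels to finish before reading the timer
print(f"Runtime (1st): {start.elapsed_time(end):.6f} ms")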

However, in the 1st approach, the calculation of B1 costs N x H x D x N = N^2 x H x D multiply-adds, which is twice as many as in the 2nd approach.

In the 2nd approach, the number of FLOPs for calculating B2 is #Edges x H x D, where #Edges = N^2/2.
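A back-of-the-envelope comparison of the two FLOP counts (again with placeholder H and D values):

N, H, D = 2000, 8, 64                          # placeholders
flops_1st = N * N * H * D                      # full [HxNxN] Gram matrix
flops_2nd = (N * N // 2) * H * D               # only the ~N^2/2 upper-triangular pairs
print(flops_1st / flops_2nd)                   # ≈ 2, yet the 1st approach above ran ~24x faster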