Hi! I recently started using einsum, and I noticed that one of the following two forms of the sum-product runs much faster than the other.
What I am trying to compute is the following:
- I want to calculate the matrix product B = A^T * A, where A^T is NxHxD and A is DxHxN. I want the result B in the form NxNxH.
- Then I select the feature vector of size 1xH for each upper-triangular element of the NxN coordinates of B.
- The question is: why is the 1st approach below faster once N becomes larger than 100, when the 2nd approach actually needs fewer flops? Did I misunderstand anything?
The 1st approach, which is quite straightforward and is the faster one, is as follows:
# A_ has shape [DxHxN]
B1 = torch.einsum('dhn,dhm->hnm', A_, A_)  # B1 = [HxNxN]
edge_index_j, edge_index_i = edge_index    # edge_index = [2x#Edges],
                                           # where #Edges = N^2/2
B1 = B1[:, edge_index_j, edge_index_i]     # B1 = [Hx#Edges]
B1 = B1.permute(1, 0)                      # B1 = [#EdgesxH]
The 2nd, slower approach is:
A = A.permute(2, 1, 0)                        # A = [NxHxD]
A_i = torch.index_select(A, 0, edge_index_i)  # A_i = [#EdgesxHxD]
A_j = torch.index_select(A, 0, edge_index_j)  # A_j = [#EdgesxHxD]
B2 = torch.einsum("ehd,ehd->eh", A_i, A_j)    # B2 = [#EdgesxH]
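To make sure both approaches really compute the same thing, here is a minimal CPU sketch that builds them side by side and compares the results. The sizes D, H, N are small hypothetical values, and building edge_index with torch.triu_indices (strictly upper triangle) is my assumption about how the edges are laid out; the real edge_index may differ.

```python
import torch

torch.manual_seed(0)

# Hypothetical small sizes, chosen only so the check runs quickly on CPU.
D, H, N = 16, 4, 50
A_ = torch.randn(D, H, N, dtype=torch.float64)  # A_ = [DxHxN]

# Assumption: edges are the strictly upper-triangular (row, col) pairs.
edge_index = torch.triu_indices(N, N, offset=1)  # [2 x #Edges]
edge_index_j, edge_index_i = edge_index

# 1st approach: full HxNxN product, then gather the selected entries.
B1 = torch.einsum('dhn,dhm->hnm', A_, A_)              # [HxNxN]
B1 = B1[:, edge_index_j, edge_index_i].permute(1, 0)   # [#Edges x H]

# 2nd approach: gather the rows first, then one dot product per edge.
A = A_.permute(2, 1, 0)                                # [NxHxD]
A_i = torch.index_select(A, 0, edge_index_i)           # [#Edges x H x D]
A_j = torch.index_select(A, 0, edge_index_j)           # [#Edges x H x D]
B2 = torch.einsum('ehd,ehd->eh', A_i, A_j)             # [#Edges x H]

print(B1.shape, B2.shape)
print(torch.allclose(B1, B2))
```

Note that the 2nd approach materializes A_i and A_j, each with #Edges x H x D elements, before the einsum runs.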
For N >= 100, I measure a lower runtime for the 1st approach than for the 2nd (here I set N = 2000):
Runtime (1st): 1.312180 ms
Runtime (2nd): 32.030568 ms
I measured the runtimes with torch.cuda.Event.
However, in the 1st approach, computing B1 costs NxHxDxN = N^2 x H x D flops, which is twice as many as in the 2nd approach.
In the 2nd approach, the number of flops for computing B2 is #Edges x H x D, where #Edges = N^2/2…
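The flop comparison above can be written out as a small worked calculation. This is only a sketch of the counting argument: it counts multiply-adds, uses the approximation #Edges = N^2/2 from the post, and the values of H and D are hypothetical placeholders since they are not stated.

```python
def full_product_mults(n, h, d):
    """Multiply-adds for 'dhn,dhm->hnm': one length-D dot product per (h, n, m)."""
    return n * n * h * d


def per_edge_mults(num_edges, h, d):
    """Multiply-adds for 'ehd,ehd->eh': one length-D dot product per (e, h)."""
    return num_edges * h * d


N = 2000          # value used in the measurement
H, D = 8, 64      # hypothetical, not given in the post
E = N * N // 2    # #Edges = N^2/2, as approximated in the post

full = full_product_mults(N, H, D)
sparse = per_edge_mults(E, H, D)

print(f"1st approach: {full} multiply-adds")
print(f"2nd approach: {sparse} multiply-adds")
print(f"ratio: {full / sparse}")
```

The ratio does not depend on H or D, since both counts scale with H x D.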