# Performance difference in torch.einsum

I would like to learn the internals of some PyTorch high-level ops, such as `einsum` and `bmm`. I created a code snippet as follows:

``````
import torch
import torch.utils.benchmark as benchmark

B, C_qk, H, N_s, N_t = 10, 128, 1, 32, 64
q = torch.randint(0, 100, (B, C_qk, H, N_s))
k = torch.randint(0, 100, (B, N_t, H, C_qk))

# The two functions compute the same contraction, with the operand order swapped.
def einsum_1(q, k):
    return torch.einsum('bchq,bkhc->bkhq', [q, k])

def einsum_2(q, k):
    return torch.einsum('bkhc,bchq->bkhq', [k, q])

t1 = benchmark.Timer(stmt='einsum_1(q, k)',
                     setup='from __main__ import einsum_1',
                     globals={'q': q, 'k': k})

t2 = benchmark.Timer(stmt='einsum_2(q, k)',
                     setup='from __main__ import einsum_2',
                     globals={'q': q, 'k': k})

# Same timers, but using all available intra-op threads.
t1_threaded = benchmark.Timer(stmt='einsum_1(q, k)',
                              setup='from __main__ import einsum_1',
                              globals={'q': q, 'k': k},
                              num_threads=torch.get_num_threads())

t2_threaded = benchmark.Timer(stmt='einsum_2(q, k)',
                              setup='from __main__ import einsum_2',
                              globals={'q': q, 'k': k},
                              num_threads=torch.get_num_threads())

print(t1.timeit(10000))
print(t2.timeit(10000))
print(t1_threaded.timeit(10000))
print(t2_threaded.timeit(10000))
``````

Here, `t2` (1.28 ms) and `t2_threaded` (1.27 ms) are consistently slower than `t1` (1.22 ms) and `t1_threaded` (1.23 ms).

Note that the two einsums yield the same output, since they express the same contraction with the operands swapped:

``````
res_1 = torch.einsum('bchq,bkhc->bkhq', [q, k])
res_2 = torch.einsum('bkhc,bchq->bkhq', [k, q])
assert torch.equal(res_1, res_2), "torch.einsum failed to auto align the summation dim"
``````
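For reference, the same contraction can also be written with explicit `permute`/`reshape` and `torch.bmm`. This is only a sketch of one possible lowering (the `bmm_equiv` helper is my own name, not a PyTorch API), not necessarily what `einsum` actually does internally:

``````
# Hypothetical helper (not part of the original snippet): the contraction
# 'bchq,bkhc->bkhq' expressed as a batched matmul over the c dimension.
def bmm_equiv(q, k):
    B, C, H, Ns = q.shape                                  # q: (B, C_qk, H, N_s)
    Nt = k.shape[1]                                        # k: (B, N_t, H, C_qk)
    q2 = q.permute(0, 2, 1, 3).reshape(B * H, C, Ns)       # (B*H, C_qk, N_s)
    k2 = k.permute(0, 2, 1, 3).reshape(B * H, Nt, C)       # (B*H, N_t, C_qk)
    out = torch.bmm(k2, q2)                                # (B*H, N_t, N_s)
    return out.reshape(B, H, Nt, Ns).permute(0, 2, 1, 3)   # (B, N_t, H, N_s)

assert torch.equal(bmm_equiv(q, k), res_1)
``````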

What are the reasons behind the performance difference? Is there a tool I can use to trace these high-level ops to see what primitive ops are called/call counts/latency breakdown?

The primitive ops can be reported by `torch.profiler`, but that breakdown by itself doesn't answer the original question of why one operand order is slower.
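For reference, here is a minimal sketch of the kind of profiling run I mean (the sort key and row limit are arbitrary choices):

``````
from torch.profiler import profile, ProfilerActivity

# Profile one einsum call on CPU and print the per-op latency breakdown;
# the table also includes a "# of Calls" column for call counts.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.einsum('bchq,bkhc->bkhq', [q, k])

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
``````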