Performance difference in torch.einsum

I would like to learn the internals of some PyTorch high-level ops, such as einsum and bmm. I created a code snippet as follows:

import torch
import torch.utils.benchmark as benchmark

B, C_qk, H, N_s, N_t = 10, 128, 1, 32, 64
q = torch.randint(0, 100, (B, C_qk, H, N_s))  # (batch, channels, heads, source length)
k = torch.randint(0, 100, (B, N_t, H, C_qk))  # (batch, target length, heads, channels)

def einsum_1(q, k):
    return torch.einsum('bchq,bkhc->bkhq', [q, k])
def einsum_2(q, k):
    return torch.einsum('bkhc,bchq->bkhq', [k, q])

t1 = benchmark.Timer(stmt='einsum_1(q, k)',
                     setup='from __main__ import einsum_1',
                     globals={'q': q, 'k': k})

t2 = benchmark.Timer(stmt='einsum_2(q, k)',
                     setup='from __main__ import einsum_2',
                     globals={'q': q, 'k': k})

print(t1.timeit(10000))
print(t2.timeit(10000))

num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

t1_threaded = benchmark.Timer(stmt='einsum_1(q, k)',
                              setup='from __main__ import einsum_1',
                              num_threads=num_threads,
                              globals={'q': q, 'k': k})

t2_threaded = benchmark.Timer(stmt='einsum_2(q, k)',
                              setup='from __main__ import einsum_2',
                              num_threads=num_threads,
                              globals={'q': q, 'k': k})

print(t1_threaded.timeit(50000))
print(t2_threaded.timeit(50000))

Here, t2 (1.28 ms) and t2_threaded (1.27 ms) are consistently slower than t1 (1.22 ms) and t1_threaded (1.23 ms).

Note that the two einsums are equivalent and yield identical outputs:

res_1 = torch.einsum('bchq,bkhc->bkhq', [q, k])
res_2 = torch.einsum('bkhc,bchq->bkhq', [k, q])
assert torch.equal(res_1, res_2), "torch.einsum failed to auto align the summation dim"
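For context, this two-operand einsum can be written by hand as a permute plus a batched matmul, which is roughly what the dispatcher lowers it to. A minimal sketch (the permutation order here is my own working, not taken from PyTorch internals):

```python
import torch

B, C_qk, H, N_s, N_t = 10, 128, 1, 32, 64
q = torch.randint(0, 100, (B, C_qk, H, N_s))  # (b, c, h, q)
k = torch.randint(0, 100, (B, N_t, H, C_qk))  # (b, k, h, c)

res_einsum = torch.einsum('bchq,bkhc->bkhq', q, k)

# Move the head dim next to batch, contract over c with a batched matmul,
# then move the head dim back: (b,h,k,c) @ (b,h,c,q) -> (b,h,k,q) -> (b,k,h,q)
res_matmul = (k.permute(0, 2, 1, 3) @ q.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

assert torch.equal(res_einsum, res_matmul)
```

If the explicit matmul version times differently from both einsums, that is a hint the overhead lives in einsum's parsing/permutation logic rather than in the contraction itself.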

What causes the performance difference? Is there a tool I can use to trace these high-level ops and see which primitive ops are called, their call counts, and the latency breakdown?

torch.profiler can report the primitive ops, but that breakdown alone doesn't answer the original question.
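For reference, a minimal torch.profiler sketch (shapes copied from the snippet above) that lists the primitive ops one of the einsums dispatches to on CPU:

```python
import torch
from torch.profiler import profile, ProfilerActivity

q = torch.randint(0, 100, (10, 128, 1, 32))
k = torch.randint(0, 100, (10, 64, 1, 128))

# Record the CPU ops dispatched by a single einsum call
with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.einsum('bchq,bkhc->bkhq', q, k)

# Aggregate by op name and print a latency breakdown table
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=10))
```

Running this once per einsum variant and diffing the two tables (op names, call counts, per-op times) is one way to localize where the extra time goes, even if it does not by itself explain why the argument order matters.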