FLOP count of the einsum operator in PyTorch?

How can I calculate the FLOP count of einsum in PyTorch? I am having a hard time doing it. I can't convert each einsum in my code to a matmul, since the code involves high-dimensional tensors and several chained matrix multiplications.
Any help would be highly appreciated!
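One way to avoid rewriting each einsum as a matmul is to count FLOPs directly from the equation string and the operand shapes. The sketch below assumes a naive evaluation: one loop over every distinct index, with (number of operands - 1) multiplications per iteration plus one accumulate-add whenever some index is summed away. For a two-operand contraction this reproduces the usual 2·M·K·N matmul convention. The function name `einsum_flops` is hypothetical, ellipsis (`...`) is not handled, and for three or more operands PyTorch may instead contract the operands pairwise (in recent versions, optionally ordered by opt_einsum), so the real work can differ from this naive count.

```python
from math import prod


def einsum_flops(equation: str, *shapes: tuple[int, ...]) -> int:
    """Estimate FLOPs of a naive einsum evaluation (hypothetical helper).

    Assumptions (a sketch, not PyTorch's own accounting):
    - no '...' ellipsis in the equation;
    - the einsum runs as one loop over every distinct index;
    - each iteration does (n_operands - 1) multiplications, plus one
      addition to accumulate into the output if any index is summed away.
    """
    equation = equation.replace(" ", "")
    if "..." in equation:
        raise ValueError("ellipsis is not supported in this sketch")
    if "->" in equation:
        lhs, out = equation.split("->")
    else:
        lhs, out = equation, None
    inputs = lhs.split(",")
    if len(inputs) != len(shapes):
        raise ValueError("number of operands does not match the equation")

    # Map each index letter to its size, checking consistency across operands.
    sizes: dict[str, int] = {}
    for subs, shape in zip(inputs, shapes):
        if len(subs) != len(shape):
            raise ValueError(f"subscripts {subs!r} do not match shape {tuple(shape)}")
        for letter, dim in zip(subs, shape):
            if sizes.setdefault(letter, dim) != dim:
                raise ValueError(f"inconsistent size for index {letter!r}")

    if out is None:
        # Implicit mode: the output keeps indices that appear exactly once, sorted.
        out = "".join(sorted(c for c in sizes if lhs.count(c) == 1))

    iterations = prod(sizes.values())           # one iteration per index combination
    multiplies = (len(shapes) - 1) * iterations
    summed = set(sizes) - set(out)              # indices contracted away
    adds = iterations if summed else 0
    return multiplies + adds


# Batched matmul: 2 * b * i * j * k = 2 * 8 * 64 * 32 * 16
print(einsum_flops("bij,bjk->bik", (8, 64, 32), (8, 32, 16)))  # 524288
```

Since `torch.Size` is a tuple subclass, real tensors can be passed as `einsum_flops("bij,bjk->bik", a.shape, b.shape)`. As an alternative that measures what actually executes, recent PyTorch releases ship `torch.utils.flop_counter.FlopCounterMode`, a context manager that counts FLOPs of the underlying aten ops (einsum is lowered to ops such as `bmm`); its coverage depends on your PyTorch version.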
