How to get the FLOPs of the matmul operations in a net?

There’s code such as [1] which walks through a net’s nn.Modules, checks their type, and calculates their FLOPs.
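A minimal sketch of that module-walking idea (not the code from [1] itself): register a forward hook on every nn.Linear and add up its FLOPs during one forward pass. The function name count_linear_flops is hypothetical, and the convention of counting one multiply-add as 2 FLOPs (ignoring the bias) is an assumption; tools that report MACs would give half these numbers.

```python
import torch
import torch.nn as nn

def count_linear_flops(model, *inputs):
    """Sum the FLOPs of every nn.Linear in `model` over one forward pass."""
    total = 0

    def hook(module, args, output):
        nonlocal total
        # Each output element is a dot product of length in_features;
        # count one multiply-add as 2 FLOPs and ignore the bias addition.
        total += 2 * output.numel() * module.in_features

    # Walk the modules and hook only the types we know how to count.
    handles = [m.register_forward_hook(hook)
               for m in model.modules() if isinstance(m, nn.Linear)]
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:
            h.remove()
    return total

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
print(count_linear_flops(net, torch.randn(2, 8)))  # 2 * (2*16*8 + 2*4*16) = 768
```

Because only modules are hooked, anything computed with a plain function call inside forward() never reaches this counter.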

However, some operations in PyTorch are only called in forward() and are never registered as modules in __init__(). For example, with the approach of [1] we would count the FLOPs of torch.nn.Linear, but we would miss every call to torch.matmul. Does anyone have an idea how to write a function that takes a net and computes the FLOPs used by its torch.matmul operations?

[1] https://github.com/Lyken17/pytorch-OpCounter

…I ended up writing nn.Module wrappers (such as a MatMul module) for the functional operations in my net. Not the most elegant solution, but it worked for my use case.
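A minimal sketch of what such a wrapper could look like, assuming a hypothetical MatMul module and a hypothetical helper count_wrapped_matmul_flops (neither is the poster's actual code). Once torch.matmul lives inside a module, it can be found with model.modules() and hooked like any other layer. For any torch.matmul output, each element is a dot product of length a.shape[-1], which also covers batched and broadcast inputs.

```python
import torch
import torch.nn as nn

class MatMul(nn.Module):
    """Hypothetical wrapper so that torch.matmul appears in model.modules()."""
    def forward(self, a, b):
        return torch.matmul(a, b)

def count_wrapped_matmul_flops(model, *inputs):
    total = 0

    def hook(module, args, output):
        nonlocal total
        a = args[0]
        # Each output element needs a.shape[-1] multiply-adds (1 MAC = 2 FLOPs).
        total += 2 * output.numel() * a.shape[-1]

    handles = [m.register_forward_hook(hook)
               for m in model.modules() if isinstance(m, MatMul)]
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:
            h.remove()
    return total

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(8, 16))
        self.matmul = MatMul()  # use the wrapper instead of calling torch.matmul

    def forward(self, x):
        return self.matmul(x, self.w)

print(count_wrapped_matmul_flops(Net(), torch.randn(4, 8)))  # 2 * 4 * 16 * 8 = 1024
```

The cost of this design is that every torch.matmul call in forward() has to be rewritten to go through a wrapper instance.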

I have seen code for converting PyTorch models to Caffe that wraps the tensor operators. It looks roughly like this:

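import torch

raw_mul = torch.Tensor.__mul__  # keep a reference to the original operator
def mul(input, *args):
    x = raw_mul(input, *args)
    # do something here, e.g. record the op and the shapes involved
    return x
torch.Tensor.__mul__ = mul  # every tensor `a * b` now goes through mul()
# note: __mul__ is elementwise `*`; the matrix product is torch.matmul / `@`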

Not the most elegant solution.
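The same patching idea can be turned into a matmul FLOP counter. A minimal sketch, assuming a hypothetical context manager count_matmul_flops that temporarily replaces torch.matmul and restores it on exit. It only sees calls that look up torch.matmul at call time; the `@` operator, Tensor.matmul, and names imported with `from torch import matmul` are not covered.

```python
import contextlib
import torch

@contextlib.contextmanager
def count_matmul_flops():
    """Temporarily wrap torch.matmul so every call adds its FLOPs to a counter."""
    counter = {"flops": 0}
    raw_matmul = torch.matmul  # keep the original function

    def matmul(input, other, *args, **kwargs):
        out = raw_matmul(input, other, *args, **kwargs)
        # Each output element needs input.shape[-1] multiply-adds; 1 MAC = 2 FLOPs.
        counter["flops"] += 2 * out.numel() * input.shape[-1]
        return out

    torch.matmul = matmul
    try:
        yield counter
    finally:
        torch.matmul = raw_matmul  # always restore the original, even on errors

a = torch.randn(2, 3, 4)
b = torch.randn(4, 5)
with count_matmul_flops() as c:
    torch.matmul(a, b)
print(c["flops"])  # 2 * (2*3*5) * 4 = 240
```

Using a context manager keeps the global patch scoped to one measurement instead of leaving torch.matmul replaced for the rest of the program.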

That makes sense too.