How to get the matmul operations in a net?

There is code, such as [1], that walks through a net's nn.Modules, checks their types, and calculates their FLOPs.
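As a rough illustration of that module-walking approach (a minimal sketch, not the code from [1]), the helper below registers a forward hook on every nn.Linear it finds and counts 2 × in_features FLOPs per output element, i.e. one multiply and one add per weight. The helper name `count_linear_flops` and the 2-FLOPs-per-multiply-add convention are assumptions for illustration.

```python
import torch
import torch.nn as nn


def count_linear_flops(net, example_input):
    """Hypothetical helper: FLOPs of every nn.Linear found by walking net.modules()."""
    total = 0

    def hook(module, inputs, output):
        nonlocal total
        # Each output element is a dot product over in_features inputs,
        # counted here as 2 FLOPs (one multiply, one add) per input feature.
        total += 2 * module.in_features * output.numel()

    # Type-based walk: only layers registered as submodules in __init__() are seen.
    handles = [m.register_forward_hook(hook)
               for m in net.modules() if isinstance(m, nn.Linear)]
    try:
        with torch.no_grad():
            net(example_input)
    finally:
        for h in handles:
            h.remove()  # leave the net unmodified afterwards
    return total


net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
print(count_linear_flops(net, torch.randn(5, 4)))  # 2*4*15 + 2*3*10 = 180
```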

However, some operations in PyTorch are called only in forward() and are never registered as modules in __init__(). For example, with the approach of [1] we would get the FLOPs of torch.nn.Linear layers, but we would miss every call to torch.matmul. Does anyone have an idea how to write a function that takes a net and computes the FLOPs used by its torch.matmul operations?
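To make the gap concrete, here is a hypothetical module (the name `Attention` and its layout are made up for illustration) that calls torch.matmul directly in forward(). Walking net.modules() finds only the module itself and its nn.Linear, so a type-based counter has nothing to hook for the matmul.

```python
import torch
import torch.nn as nn


class Attention(nn.Module):  # hypothetical example net
    def __init__(self, dim):
        super().__init__()
        # Registered in __init__(): visible to a module walk.
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        q = self.proj(x)
        # torch.matmul is only called here, so it never appears in net.modules().
        return torch.matmul(q, x.transpose(-2, -1))


net = Attention(8)
module_types = [type(m).__name__ for m in net.modules()]
print(module_types)  # ['Attention', 'Linear']
```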


…I ended up writing nn.Module wrappers for the functional operations in my net, such as a MatMul module around torch.matmul. Not the most elegant solution, but it worked for my use case.
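A sketch of that wrapper idea, assuming a hypothetical `Attention` net: once torch.matmul lives inside a `MatMul` submodule, it is registered in `__init__()` and the same hook-based counting works. The FLOP formula 2 × k × (number of output elements), where k is the shared inner dimension, follows the usual multiply-add convention and is an assumption here, as is the helper name `count_matmul_flops`.

```python
import torch
import torch.nn as nn


class MatMul(nn.Module):
    """Thin nn.Module wrapper so torch.matmul shows up in net.modules()."""

    def forward(self, a, b):
        return torch.matmul(a, b)


class Attention(nn.Module):  # hypothetical example net
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.matmul = MatMul()  # registered in __init__(), so a module walk finds it

    def forward(self, x):
        q = self.proj(x)
        return self.matmul(q, x.transpose(-2, -1))


def count_matmul_flops(net, *inputs):
    """Hypothetical helper: FLOPs of every MatMul module during one forward pass."""
    total = 0

    def hook(module, args, output):
        nonlocal total
        a = args[0]
        # Each output element is a dot product of length k = a.shape[-1]:
        # k multiplies and k adds, counted as 2 * k FLOPs.
        total += 2 * a.shape[-1] * output.numel()

    handles = [m.register_forward_hook(hook)
               for m in net.modules() if isinstance(m, MatMul)]
    try:
        with torch.no_grad():
            net(*inputs)
    finally:
        for h in handles:
            h.remove()
    return total


print(count_matmul_flops(Attention(8), torch.randn(2, 5, 8)))  # 2 * 8 * 50 = 800
```

The cost of this design is that every functional call in forward() has to be rewritten by hand to go through a wrapper module.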

I have seen code for converting PyTorch models to Caffe that patches tensor operators like this:

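```python
import torch

raw_mul = torch.Tensor.__mul__  # keep a reference to the original method

def mul(input, *args):
    x = raw_mul(input, *args)
    # do something here, e.g. record the operation and its input/output shapes
    return x  # the patched method must still return the result

torch.Tensor.__mul__ = mul
```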

Not the most elegant solution.
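Applied to the original question, the same patching trick could count matmul FLOPs. The sketch below (the helper name `count_matmul_flops` is hypothetical) swaps torch.matmul for a counting wrapper only inside a `with` block and restores it afterwards. It only sees calls that look up `torch.matmul` at call time; code that uses the `@` operator, `Tensor.matmul`, or a name bound via `from torch import matmul` may go through other entry points and would need its own patch.

```python
import contextlib

import torch


@contextlib.contextmanager
def count_matmul_flops():
    """Hypothetical helper: temporarily patch torch.matmul to count FLOPs."""
    raw_matmul = torch.matmul  # keep the original so it can be restored
    stats = {"flops": 0}

    def matmul(input, other, *args, **kwargs):
        out = raw_matmul(input, other, *args, **kwargs)
        # Inner dimension k = input.shape[-1]; each output element costs
        # k multiplies and k adds, counted as 2 * k FLOPs.
        stats["flops"] += 2 * input.shape[-1] * out.numel()
        return out

    torch.matmul = matmul
    try:
        yield stats
    finally:
        torch.matmul = raw_matmul  # always undo the patch


a, b = torch.randn(4, 3), torch.randn(3, 5)
with count_matmul_flops() as stats:
    torch.matmul(a, b)
print(stats["flops"])  # 2 * 3 * 20 = 120
```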

That makes sense too.