There’s code such as [1] which walks through a net’s nn.Modules, checks their type, and calculates their FLOPs.
However, some operations in PyTorch are called directly in `forward()` rather than being registered as modules in `__init__()`. For example, using the methodology of [1], we would get the FLOPs of `torch.nn.Linear`, but we would miss any calls to `torch.matmul`. Does anyone have an idea of how to write a function that takes a net and computes the FLOPs used by its `torch.matmul` operations?
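One possible approach, sketched here under some assumptions, is to run a forward pass inside a `torch.overrides.TorchFunctionMode`, which sees every torch-level call made in `forward()`, including ones that never appear as `nn.Module`s. The class `MatmulFlopCounter`, the helper `count_matmul_flops`, and the toy model `ToyAttention` are hypothetical names for illustration. The sketch counts one multiply-add as 2 FLOPs and uses the fact that each output element of `torch.matmul(a, b)` needs `a.shape[-1]` multiply-adds, so it covers 1-D, 2-D, and batched/broadcast operands alike. It only intercepts `torch.matmul`, `Tensor.matmul`, and the `@` operator; `torch.mm`, `torch.bmm`, `torch.einsum`, etc. would need to be added to the set separately.

```python
import torch
import torch.nn as nn
from torch.overrides import TorchFunctionMode

# Functions treated as "matmul" by this sketch. `a @ b` on tensors is
# dispatched as Tensor.matmul; __matmul__ is listed as a safety net.
_MATMUL_FUNCS = {torch.matmul, torch.Tensor.matmul, torch.Tensor.__matmul__}


class MatmulFlopCounter(TorchFunctionMode):
    """Hypothetical helper: accumulates FLOPs of matmul calls made while active."""

    def __init__(self):
        super().__init__()
        self.flops = 0

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # The mode is disabled while its handler runs, so this call does not recurse.
        out = func(*args, **kwargs)
        if func in _MATMUL_FUNCS:
            a = args[0]  # left operand
            # Each output element is a dot product of length K = a.shape[-1],
            # i.e. K multiplies + K adds ~= 2 * K FLOPs.
            self.flops += 2 * out.numel() * a.shape[-1]
        return out


def count_matmul_flops(model, *inputs):
    """Run one forward pass and return the FLOPs spent in torch.matmul / @."""
    counter = MatmulFlopCounter()
    with torch.no_grad(), counter:
        model(*inputs)
    return counter.flops


class ToyAttention(nn.Module):
    """Hypothetical toy model mixing nn.Linear with explicit matmuls."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, d)  # seen as F.linear, so NOT counted here

    def forward(self, x):
        q = self.proj(x)
        scores = torch.matmul(q, q.transpose(-2, -1))  # (B, T, T)
        return scores @ x  # (B, T, d)


if __name__ == "__main__":
    model = ToyAttention(8)
    print(count_matmul_flops(model, torch.randn(2, 4, 8)))
```

Because `nn.Linear` reaches the mode as `F.linear` rather than as `torch.matmul`, its cost is not double-counted here and can still come from a module-walking counter like [1]. Recent PyTorch 2.x releases also include `torch.utils.flop_counter.FlopCounterMode`, which counts FLOPs at the ATen-operator level for the whole forward pass and may be an alternative if you want both kinds of operations in a single number.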