Calculating flops of a given PyTorch model

I am trying to write a function that calculates flops, and I would like to discuss it. Many papers report flop counts, but it is hard to find the details of how they were computed.

I have some questions:

  • Is it normal to include the flops of ReLU, batch normalization, …?
  • It seems common to take the spatial dimensions into account. For example, to calculate the flops of a Conv2d layer, I need to know the image size. What input width and height are commonly used?
  • I am writing the code below, but I wonder whether there is a more elegant way to compute this. (In my code, it is hard to track spatial dimensions that change through the network. For example, for a ReLU layer, we don't know the shape of the input it receives from the previous layer.)
```python
import re

import torchvision


def get_num_gen(gen):
    """Return the number of items produced by a generator."""
    return sum(1 for _ in gen)


def flops_layer(layer):
    """
    Calculate the number of flops for a layer, given its string representation.
    Only the relevant numbers are extracted from the string and used.

    Args:
        layer (str): for example
            Linear (512 -> 1000)
            Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    """
    # Integers not immediately preceded by a lowercase letter
    # (this skips the "2" in "Conv2d").
    params = re.findall(r'[^a-z](\d+)', layer)
    flops = 1  # placeholder cost for layer types not handled below

    if 'Linear' in layer:
        C1 = int(params[0])  # in_features
        C2 = int(params[1])  # out_features
        flops = C1 * C2

    elif 'Conv2d' in layer:
        C1 = int(params[0])  # in_channels
        C2 = int(params[1])  # out_channels
        K1 = int(params[2])  # kernel height
        K2 = int(params[3])  # kernel width

        # Fixed image size: stride and downsampling are ignored,
        # so every conv layer is assumed to see a 32x32 input.
        H = 32
        W = 32
        flops = C1 * C2 * K1 * K2 * H * W

    return flops


def calculate_flops(gen):
    """
    Calculate the flops given a generator of PyTorch modules.
    Only the forward pass is counted.

    Example:
        >>> net = torchvision.models.resnet18()
        >>> calculate_flops(net.children())
    """
    flops = 0

    for child in gen:
        num_children = get_num_gen(child.children())

        if num_children == 0:
            # leaf node: estimate its cost from its string representation
            flops += flops_layer(str(child))
        else:
            # container: recurse into its submodules
            flops += calculate_flops(child.children())

    return flops


net = torchvision.models.resnet18()
flops = calculate_flops(net.children())
print(flops / 10**9, 'G')
# 11.435429919 G
```

For ResNets, the input spatial size is 224x224 (height x width). For Inception v3 it is 299x299.
Generally, since the majority of flops are in conv and linear layers, writing "nflops ≈ X" signals that you are approximating, and that is probably sufficient for almost all purposes.
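As a worked example of why conv layers dominate, here is the usual multiply-accumulate (MAC) count for a Conv2d, applied to the first ResNet-18 convolution (3 to 64 channels, 7x7 kernel, stride 2, padding 3) on a 224x224 input. This assumes dilation 1 and groups 1, and ignores the bias. Whether one MAC is reported as one flop or two varies between papers, which is one more reason to write "≈".

```latex
H_{out} = \left\lfloor \frac{H_{in} + 2p - k}{s} \right\rfloor + 1
        = \left\lfloor \frac{224 + 2 \cdot 3 - 7}{2} \right\rfloor + 1 = 112

\text{MACs} = C_{out} \cdot H_{out} \cdot W_{out} \cdot C_{in} \cdot k_h \cdot k_w
            = 64 \cdot 112 \cdot 112 \cdot 3 \cdot 7 \cdot 7 = 118\,013\,952 \approx 1.18 \times 10^{8}
```

Note that $H_{out}$ and $W_{out}$ depend on every stride and pooling step before the layer, so a single fixed image size cannot be used for all layers.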

Your implementation misses a critical issue: the spatial size of the conv layers' input changes with model depth, because strides and pooling downsample it.
So I think the right way to compute the number of flops is with forward hooks, which see the actual input and output shapes of each layer.
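A minimal sketch of the forward-hook approach, assuming that only `Conv2d` and `Linear` layers are counted (ReLU, BatchNorm and pooling are ignored), that one multiply-accumulate (MAC) is counted per kernel tap, and that bias terms are ignored. The function name `count_macs` and its `input_size` parameter are hypothetical, not from any library. Because each hook receives the layer's real output tensor, the changing spatial size is handled automatically.

```python
import torch
import torch.nn as nn


def count_macs(model: nn.Module, input_size=(1, 3, 224, 224)) -> int:
    """Hypothetical helper: count MACs of Conv2d and Linear layers via forward hooks.

    Other layer types are ignored. Bias additions are not counted.
    With a batch size other than 1, the result covers the whole batch.
    """
    total = 0
    handles = []

    def conv_hook(module, inputs, output):
        nonlocal total
        # Each output element needs (C_in / groups) * k_h * k_w multiply-adds.
        k_h, k_w = module.kernel_size
        per_output = (module.in_channels // module.groups) * k_h * k_w
        total += output.numel() * per_output

    def linear_hook(module, inputs, output):
        nonlocal total
        # Each output element needs in_features multiply-adds.
        total += output.numel() * module.in_features

    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))

    # eval() so BatchNorm/Dropout behave deterministically; this changes the model's mode.
    model.eval()
    with torch.no_grad():
        model(torch.zeros(*input_size))

    # Remove the hooks so the model is left as it was.
    for h in handles:
        h.remove()
    return total
```

For example, `count_macs(torchvision.models.resnet18())` counts at the default 224x224 input; multiply the result by 2 if your convention counts a multiply and an add as separate flops.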

If someone still needs this, we wrote a small script to do that:

Example of usage:


I just wrote a simple script to calculate the flops:

https://zhuanlan.zhihu.com/p/33992733
Search for the function named “print_model_parm_flops”.


model.apply(fn) combined with module.register_forward_hook(hook) allows for easy tracking of layers, but only works for nn.Module layers (conv, batchnorm, etc.). This is effective in the majority of cases, but it does not track functional calls, e.g. F.interpolate(...). Is there any way to detect functional calls in a forward pass?
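A small demonstration of the limitation described above, using a hypothetical toy model `UpNet` (not from the thread): `model.apply` registers a hook on every leaf module, but the `F.interpolate` call inside `forward` is not a module, so nothing records it even though it changes the output shape.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UpNet(nn.Module):
    """Hypothetical toy model: one conv layer followed by a functional upsample."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv(x)
        return F.interpolate(x, scale_factor=2)  # functional call: no module, no hook


seen = []


def attach(module):
    # Register a hook only on leaf modules (modules without children).
    if len(list(module.children())) == 0:
        module.register_forward_hook(
            lambda m, inp, output: seen.append((type(m).__name__, tuple(output.shape)))
        )


model = UpNet()
model.apply(attach)
out = model(torch.zeros(1, 3, 8, 8))
print(seen)             # [('Conv2d', (1, 4, 8, 8))]
print(tuple(out.shape)) # (1, 4, 16, 16)
```

Only the conv layer appears in `seen`; the 8x8 to 16x16 upsample happens without any hook firing.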


Coming to this rather late, but in case people are interested:

It is possible to measure the floating point operation count of models directly, using CPU performance monitoring units, as an alternative to the approaches that track the flops of each operation. With the python-papi module this is quite easy to do, and the results match the operation counts implemented by the thop module: see http://www.bnikolic.co.uk/blog/python/flops/2019/10/01/pytorch-count-flops.html for a comparison.