Correct way to calculate FLOPS in model

I want to calculate the FLOPS of my model for every epoch. I’ve come across a few posts and GitHub issues that discuss this, but I’m not sure they calculate it correctly. Any suggestions? If it helps, my model is a transformer.

@ptrblck Could you please recommend something?

Unfortunately I cannot recommend a particular lib or method, but what are your current concerns about the code you’ve found so far?

They are approximations and mainly support Conv and Linear layers. I don’t have any conv layers, and I want the approach to be reliable so that I don’t report inconsistent results because of implementation details. The feature request for official support, opened in 2018, is still open.

I looked more carefully: their implementation does not account for several modules such as nn.Embedding, LayerNorm, GELU, etc., so these approaches may not give an accurate FLOPS count. I’m referring to the ones Soumith recommended here.

FLOP count is a property of an algorithm rather than a model. Does a Linear layer have 2mqp or mq(2p-1) FLOPs? It depends on how the matmul is performed; see the discussion here. You can get an approximate count by assuming some reference implementation.
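To make the two conventions concrete, assume the input is m×p and the weight is p×q. The product has m·q outputs, each a dot product of length p. Counting every multiply and add separately gives p multiplies and p−1 additions per output, i.e. mq(2p−1); counting each multiply-add pair as 2 FLOPs gives 2mqp. A small sketch (the function names are hypothetical):

```python
def linear_flops_naive(m: int, p: int, q: int) -> int:
    """(m x p) @ (p x q): each of the m*q outputs needs p multiplies and p - 1 adds."""
    return m * q * (2 * p - 1)


def linear_flops_fma(m: int, p: int, q: int) -> int:
    """Same product, counting every multiply-add pair as 2 FLOPs."""
    return 2 * m * q * p


# Example: 32 tokens through a 512 -> 2048 projection
m, p, q = 32, 512, 2048
print(linear_flops_naive(m, p, q))  # 67043328
print(linear_flops_fma(m, p, q))    # 67108864
```

The two counts differ by exactly m·q, a relative difference of about 1/(2p), so for realistic layer sizes the choice matters less than stating which convention you used.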

nn.Embedding is a dictionary lookup, so technically it has 0 FLOPS.
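A quick check of that claim: with default arguments (no max_norm), the forward pass of nn.Embedding gives the same result as indexing its weight matrix, so a counter that follows the arithmetic has nothing to count for it. The sizes below are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
token_ids = torch.tensor([[1, 5, 5, 9]])

# The forward pass selects rows of the weight table by index;
# no multiplications or additions are involved.
out = emb(token_ids)
same = torch.equal(out, emb.weight[token_ids])
print(out.shape, same)  # torch.Size([1, 4, 4]) True
```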

Since the FLOP count is going to be approximate anyway, you only care about the layers that are heaviest to compute. You could profile your model and check whether any expensive layers are not already covered. TensorFlow has some reference formulas here.
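One way to do that profiling, assuming a reasonably recent PyTorch: the built-in profiler accepts with_flops=True and attaches a FLOP estimate to the operators it has formulas for (matrix multiplications and convolutions among them). Operators without a formula get no count, which makes it easy to see what a count leaves out. A sketch on a made-up feed-forward block:

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# Made-up Transformer-style feed-forward block (sizes are arbitrary)
model = nn.Sequential(
    nn.LayerNorm(64),
    nn.Linear(64, 256),
    nn.GELU(),
    nn.Linear(256, 64),
)
model.eval()
x = torch.randn(32, 64)  # 32 tokens, d_model = 64

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU], with_flops=True) as prof:
    model(x)

# Per-operator table; operators without a FLOP formula show an empty or zero count
print(prof.key_averages().table(sort_by="flops", row_limit=10))

# Sum the estimates over all recorded operators
total_flops = sum(evt.flops for evt in prof.key_averages())
print(total_flops)
```

In this sketch the count comes from the matrix multiplications inside the two Linear layers, while LayerNorm and GELU contribute nothing, which mirrors the gaps mentioned earlier in the thread.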

Thanks for your response. Could you also clarify the relation between MACs and FLOPS?

I guess this answers your question
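For reference, the usual convention is that one multiply-accumulate (MAC) is one multiplication plus one addition, so FLOPs ≈ 2 × MACs for matmul-dominated models, which matches the 2mqp convention above. Tools do not all follow it: fvcore's documentation, for instance, says it counts one fused multiply-add as one flop, so its totals are effectively MACs. A small sketch for a Linear layer without bias (helper names are hypothetical):

```python
def linear_macs(batch: int, in_features: int, out_features: int) -> int:
    # Each output element is a dot product of length in_features,
    # i.e. in_features multiply-accumulate operations.
    return batch * out_features * in_features


def macs_to_flops(macs: int) -> int:
    # Convention: 1 MAC = 1 multiply + 1 add = 2 FLOPs
    return 2 * macs


macs = linear_macs(batch=8, in_features=768, out_features=3072)
print(macs, macs_to_flops(macs))  # 18874368 37748736
```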

Hi @kl_divergence, I was wondering the same thing, and I found this: https://github.com/facebookresearch/SlowFast/blob/0cc82440fee6e51a5807853b583be238bf26a253/slowfast/utils/misc.py#L106

It’s from the official facebookresearch account, so I guess we can rely on it (?). Have a look.

Our computer vision team at Facebook AI has released a tool to compute and summarize the flop count of any PyTorch model: fvcore/flop_count.md in the facebookresearch/fvcore repository on GitHub. Please check it out!

Thanks for the nice work.
Is FLOPS counting also supported for spectral normalization?
(torch.nn.utils.spectral_norm, PyTorch 1.8.1 documentation)

This normalization is ignored by default, but it should take negligible FLOPs anyway.

@ppwwyyxx How do I find the FLOPS of a pretrained PyTorch object detection model?

from fvcore.nn import FlopCountAnalysis
import torchvision
from PIL import Image
import torch
import torchvision.transforms as T

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
img_path = "/content/drive/MyDrive/dataset/images/fr-1751.jpg"
img = Image.open(img_path)
transform = T.Compose([T.ToTensor()])
img = transform(img)
# pred = model([img])
flops = FlopCountAnalysis(model, [img])
flops.total()

I get this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-56-a72097dab463> in <module>()
     12 # pred = model([img])
     13 flops = FlopCountAnalysis(model, [img])
---> 14 flops.total()
     15 
     16 flops.by_operator()

9 frames
/usr/local/lib/python3.7/dist-packages/fvcore/nn/jit_analysis.py in total(self, module_name)
    246             int : The aggregated statistic.
    247         """
--> 248         stats = self._analyze()
    249         module_name = self.canonical_module_name(module_name)
    250         total_count = sum(stats.counts[module_name].values())

/usr/local/lib/python3.7/dist-packages/fvcore/nn/jit_analysis.py in _analyze(self)
    549             elif self._warn_trace == "no_tracer_warning":
    550                 warnings.filterwarnings("ignore", category=TracerWarning)
--> 551             graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
    552 
    553         # Assures even modules not in the trace graph are initialized to zero count

/usr/local/lib/python3.7/dist-packages/fvcore/nn/jit_analysis.py in _get_scoped_trace_graph(module, inputs, aliases)
    174         register_hooks(mod, name)
    175 
--> 176     graph, _ = _get_trace_graph(module, inputs)
    177 
    178     for handle in hook_handles:

/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1164     if not isinstance(args, tuple):
   1165         args = (args,)
-> 1166     outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
   1167     return outs

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py in forward(self, *args)
    130             _create_interpreter_name_lookup_fn(),
    131             self.strict,
--> 132             self._force_outplace,
    133         )
    134 

/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py in wrapper(*args)
    116             if self._return_inputs_states:
    117                 inputs_states.append(_unflatten(in_args, in_desc))
--> 118             outs.append(self.inner(*trace_inputs))
    119             if self._return_inputs_states:
    120                 inputs_states[0] = (inputs_states[0], trace_inputs)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1126             input = bw_hook.setup_input_hook(input)
   1127 
-> 1128         result = forward_call(*input, **kwargs)
   1129         if _global_forward_hooks or self._forward_hooks:
   1130             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
   1096                 recording_scopes = False
   1097         try:
-> 1098             result = self.forward(*input, **kwargs)
   1099         finally:
   1100             if recording_scopes:

/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     59         """
     60         if self.training and targets is None:
---> 61             raise ValueError("In training mode, targets should be passed")
     62         if self.training:
     63             assert targets is not None

ValueError: In training mode, targets should be passed

The error message says it clearly:

ValueError: In training mode, targets should be passed

Set model.eval() in order to make predictions with the model.

That worked. Thank you!

Could you please tell me the formula behind the FLOPs calculation for a batch normalization layer?
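As with Linear, the answer depends on the assumed implementation. One common accounting, sketched here under that assumption and not necessarily what fvcore or any other tool uses, covers inference mode: BatchNorm computes y = (x − μ)/√(σ² + ε)·γ + β with fixed per-channel statistics. These can be folded into a per-channel scale and shift, which leaves one multiply and one add per element. Training mode also computes the batch mean and variance, which adds further per-element work not counted here. Helper names are hypothetical:

```python
import math


def batchnorm_eval_flops(shape: tuple) -> int:
    """FLOPs for inference-mode BatchNorm on an input of `shape` (N, C, ...),
    assuming the per-channel scale/shift are precomputed:
        scale = gamma / sqrt(running_var + eps)
        shift = beta - running_mean * scale
        y = x * scale + shift   -> 1 multiply + 1 add per element
    The O(C) precomputation is ignored as negligible."""
    return 2 * math.prod(shape)


def bn_reference(x, mean, var, gamma, beta, eps=1e-5):
    # Textbook formula for a single element
    return (x - mean) / math.sqrt(var + eps) * gamma + beta


def bn_folded(x, mean, var, gamma, beta, eps=1e-5):
    # Folded form used for the count above
    scale = gamma / math.sqrt(var + eps)
    shift = beta - mean * scale
    return x * scale + shift


print(batchnorm_eval_flops((8, 64, 56, 56)))  # 3211264
print(math.isclose(bn_reference(0.7, 0.2, 1.5, 0.9, -0.1),
                   bn_folded(0.7, 0.2, 1.5, 0.9, -0.1)))  # True
```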

I have tried multiple tools, including fvcore, and many of them didn’t work for several submodules of my model.
Then I found DeepSpeed, and it works seamlessly, with a very detailed analysis of each submodule.

I really appreciate their work.
