I want to calculate the FLOPs of my model for every epoch. I’ve come across a few posts and GitHub issues that discuss this, but I’m not sure they are calculating it correctly. Any suggestions? If it helps, my model is a transformer.
@ptrblck Could you please recommend something?
Unfortunately I cannot recommend a particular lib or method, but what are your current concerns about the code you’ve found so far?
They only give an approximation, and they mainly support Conv and Linear layers. I don’t have any conv layers, and I want the approach to be reliable so that I don’t report inconsistent results caused by implementation differences. The feature request for official support, opened in 2018, is still open.
I looked more carefully, and their implementation does not account for several modules such as nn.Embedding, LayerNorm, GeLU(), etc., so these approaches may not give an accurate FLOP count. I’m referring to the ones Soumith recommended here.
FLOP count is a property of an algorithm rather than a model. Does a Linear layer have 2mqp or mq(2p-1) FLOPs? It depends on how the matmul is performed; see the discussion here. You can get an approximate count by assuming some reference implementation.
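To make the two conventions concrete, here is a minimal pure-Python sketch. It assumes an input of shape m×p multiplied by a weight of shape p×q (this mapping of m, p, q is an assumption, not taken from the linked discussion). A textbook triple loop does p multiplies and p-1 adds per output element, giving mq(2p-1); counting every multiply-add as 2 FLOPs gives 2mqp. The function names are hypothetical.

```python
def naive_matmul_counted(a, b):
    """Multiply a (m x p) by b (p x q) with a textbook triple loop, counting FLOPs.

    Assumes p >= 1 and well-formed nested lists.
    """
    m, p, q = len(a), len(b), len(b[0])
    mults = adds = 0
    out = [[0.0] * q for _ in range(m)]
    for i in range(m):
        for j in range(q):
            acc = a[i][0] * b[0][j]  # first product starts the sum, so no add yet
            mults += 1
            for k in range(1, p):
                acc += a[i][k] * b[k][j]
                mults += 1
                adds += 1
            out[i][j] = acc
    return out, mults + adds  # mq*p multiplies + mq*(p-1) adds = mq(2p-1)


def fma_flops(m, p, q):
    # Convention where each of the p multiply-adds per output element counts as 2 FLOPs.
    return 2 * m * p * q
```

Adding a bias contributes mq further additions, which brings the naive count to exactly 2mqp; that is one reason the two conventions are often used interchangeably.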
nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs.
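A small pure-Python sketch of why the lookup counts as zero arithmetic, and how the same output computed a different way (a one-hot vector times the embedding table) would cost 2 × vocab × dim FLOPs per token. The function names are hypothetical, and the one-hot version is only there for comparison; it is not how nn.Embedding computes its output.

```python
def embedding_lookup(table, ids):
    # Lookup-style forward: copy the selected rows, no arithmetic at all.
    return [list(table[i]) for i in ids], 0


def embedding_as_onehot_matmul(table, ids):
    # Same result computed as one_hot(id) @ table, counting 2 FLOPs per multiply-add.
    vocab, dim = len(table), len(table[0])
    flops = 0
    out = []
    for i in ids:
        onehot = [1.0 if v == i else 0.0 for v in range(vocab)]
        row = []
        for d in range(dim):
            acc = 0.0
            for v in range(vocab):
                acc += onehot[v] * table[v][d]
                flops += 2  # one multiply + one add
            row.append(acc)
        out.append(row)
    return out, flops
```

Both functions return the same rows, which is the point made above: the FLOP count belongs to the algorithm, not to the layer's output.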
Since the FLOP count is going to be approximate anyway, you only need to care about the most compute-heavy layers. You could profile your model to see whether any expensive layers are not covered already. TensorFlow has some reference formulas here.
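For a transformer, the heaviest layers are usually the matmuls. The sketch below counts only those for one encoder layer (Q/K/V projections, attention scores, attention-weighted values, output projection, two-layer FFN) at 2 FLOPs per multiply-add, then scales to an epoch using the common rule of thumb that the backward pass costs about 2× the forward pass. Softmax, LayerNorm, GELU, biases, and embeddings are ignored, and all function and parameter names are assumptions for illustration.

```python
def encoder_layer_forward_flops(seq_len: int, d_model: int, d_ff: int) -> int:
    """Approximate forward FLOPs for ONE sequence through ONE encoder layer (matmuls only)."""
    n, d = seq_len, d_model
    qkv_proj = 3 * 2 * n * d * d     # Q, K, V projections
    scores = 2 * n * n * d           # Q @ K^T, summed over all heads
    weighted_values = 2 * n * n * d  # softmax(scores) @ V, summed over all heads
    out_proj = 2 * n * d * d         # output projection
    ffn = 2 * 2 * n * d * d_ff       # d_model -> d_ff -> d_model
    return qkv_proj + scores + weighted_values + out_proj + ffn


def training_flops_per_epoch(num_sequences, num_layers, seq_len, d_model, d_ff,
                             backward_factor=2.0):
    # Rule of thumb (an assumption): backward ~ 2x forward, so one step ~ 3x forward.
    forward = num_layers * encoder_layer_forward_flops(seq_len, d_model, d_ff)
    return num_sequences * forward * (1 + backward_factor)
```

If sequences are padded to a fixed length, seq_len should be the padded length, since the padded positions are still computed.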
Thanks for your response. Also, could you please clarify the relation between MACs and FLOPs?
I guess this answers your question
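Under the most common convention, one multiply-accumulate (MAC) is one multiply plus one add, so FLOPs ≈ 2 × MACs. Some counting tools report MACs under the name "FLOPs", so it is worth checking a tool's convention before comparing numbers. A minimal sketch for a Linear layer (function names hypothetical, bias ignored):

```python
def linear_macs(in_features: int, out_features: int, tokens: int) -> int:
    # Each output element needs in_features multiply-accumulates.
    return tokens * in_features * out_features


def macs_to_flops(macs: int) -> int:
    # Convention: one MAC = one multiply + one add = 2 FLOPs.
    return 2 * macs
```

For example, a 768 → 3072 Linear layer applied to 512 tokens is about 1.2 GMACs, or about 2.4 GFLOPs.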
Hi @kl_divergence, I was wondering the same thing, and I found this: https://github.com/facebookresearch/SlowFast/blob/0cc82440fee6e51a5807853b583be238bf26a253/slowfast/utils/misc.py#L106
It is from the official facebookresearch account, so I guess we can rely on it(?). Have a look.
Our team at Facebook AI computer vision has released a tool to compute and summarize the flop count of any pytorch model: fvcore/flop_count.md at master · facebookresearch/fvcore · GitHub. Please check it out!
Thanks for the nice work.
Is FLOP counting for spectral normalization also supported?
(torch.nn.utils.spectral_norm — PyTorch 1.8.1 documentation)
This normalization is ignored by default, but it should take a negligible number of FLOPs anyway.
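A back-of-the-envelope check of "negligible", assuming the power-iteration approach with one iteration per forward pass (n_power_iterations=1 is the default of torch.nn.utils.spectral_norm): each iteration costs two matrix-vector products with the weight, and computing sigma needs roughly one more, while the layer itself does a full matrix multiply over every token in the batch. Vector normalizations are ignored, and the function names are hypothetical.

```python
def spectral_norm_flops(out_features: int, in_features: int,
                        n_power_iterations: int = 1) -> int:
    """Rough FLOPs for one spectral-norm weight update (matrix-vector products only)."""
    matvec = 2 * out_features * in_features  # one W @ v or W^T @ u
    # Each power iteration: W^T @ u, then W @ v; sigma = u . (W @ v) adds about one more matvec.
    return (2 * n_power_iterations + 1) * matvec


def linear_forward_flops(out_features: int, in_features: int, tokens: int) -> int:
    # The layer's own matmul over every token in the batch.
    return 2 * tokens * out_features * in_features
```

For a 1024×1024 Linear layer applied to 4096 tokens, the ratio is 3/4096, i.e. under 0.1% of the layer's own matmul cost.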
@ppwwyyxx How do I find the FLOPs of a pretrained PyTorch object detection model?
from fvcore.nn import FlopCountAnalysis
import torchvision
from PIL import Image
import torch
import torchvision.transforms as T
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
img_path = "/content/drive/MyDrive/dataset/images/fr-1751.jpg"
img = Image.open(img_path)
transform = T.Compose([T.ToTensor()])
img = transform(img)
# pred = model([img])
flops = FlopCountAnalysis(model, [img])
flops.total()
I get this error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-56-a72097dab463> in <module>()
12 # pred = model([img])
13 flops = FlopCountAnalysis(model, [img])
---> 14 flops.total()
15
16 flops.by_operator()
/usr/local/lib/python3.7/dist-packages/fvcore/nn/jit_analysis.py in total(self, module_name)
246 int : The aggregated statistic.
247 """
--> 248 stats = self._analyze()
249 module_name = self.canonical_module_name(module_name)
250 total_count = sum(stats.counts[module_name].values())
/usr/local/lib/python3.7/dist-packages/fvcore/nn/jit_analysis.py in _analyze(self)
549 elif self._warn_trace == "no_tracer_warning":
550 warnings.filterwarnings("ignore", category=TracerWarning)
--> 551 graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
552
553 # Assures even modules not in the trace graph are initialized to zero count
/usr/local/lib/python3.7/dist-packages/fvcore/nn/jit_analysis.py in _get_scoped_trace_graph(module, inputs, aliases)
174 register_hooks(mod, name)
175
--> 176 graph, _ = _get_trace_graph(module, inputs)
177
178 for handle in hook_handles:
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1164 if not isinstance(args, tuple):
1165 args = (args,)
-> 1166 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
1167 return outs
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py in forward(self, *args)
130 _create_interpreter_name_lookup_fn(),
131 self.strict,
--> 132 self._force_outplace,
133 )
134
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py in wrapper(*args)
116 if self._return_inputs_states:
117 inputs_states.append(_unflatten(in_args, in_desc))
--> 118 outs.append(self.inner(*trace_inputs))
119 if self._return_inputs_states:
120 inputs_states[0] = (inputs_states[0], trace_inputs)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1126 input = bw_hook.setup_input_hook(input)
1127
-> 1128 result = forward_call(*input, **kwargs)
1129 if _global_forward_hooks or self._forward_hooks:
1130 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1096 recording_scopes = False
1097 try:
-> 1098 result = self.forward(*input, **kwargs)
1099 finally:
1100 if recording_scopes:
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
59 """
60 if self.training and targets is None:
---> 61 raise ValueError("In training mode, targets should be passed")
62 if self.training:
63 assert targets is not None
ValueError: In training mode, targets should be passed
The error message says it clearly:
ValueError: In training mode, targets should be passed
Set model.eval() in order to make predictions with the model.
That worked. Thank you!
Could you please let me know the formula behind the FLOP calculation for a batch normalization layer?
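There is no single official formula; the count depends on the convention. The sketch below uses one rough convention, an assumption rather than what any particular tool does. In eval mode the running statistics are fixed, so (x - mean) / sqrt(var + eps) * gamma + beta folds into a per-channel scale and shift: 1 multiply + 1 add per element. In training mode the batch mean (1 add per element) and variance (subtract, square, add: 3 per element) are computed first, then normalization (subtract, multiply by the reciprocal std: 2) and the optional affine step (2). Per-channel work such as the sqrt is ignored as negligible. The function name is hypothetical.

```python
def batchnorm_flops(numel: int, training: bool = False, affine: bool = True) -> int:
    """Rough FLOP count for BatchNorm over a tensor with `numel` elements (one convention)."""
    if not training:
        # Folded form y = x * scale + shift with per-channel constants
        # (without affine it is (x - mean) * inv_std, still 2 ops per element).
        return 2 * numel
    per_element = 1 + 3 + 2  # mean accumulation + variance (sub, square, add) + normalize (sub, mul)
    if affine:
        per_element += 2     # gamma multiply + beta add
    return per_element * numel
```

For example, a BatchNorm2d on an 8×64×56×56 activation would be counted as 2 FLOPs per element in eval mode and 8 per element in training mode under this convention.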
I tried multiple tools, including fvcore, and many of them didn’t work for several submodules of my model.
Then I found DeepSpeed, and it works seamlessly, with a very detailed analysis of each submodule.
I really appreciate their work.