How to calculate FLOPS of a JIT model?

I was using FlopCountAnalysis (from fvcore.nn import FlopCountAnalysis) to calculate the FLOPS of my PyTorch models, and it worked very well.

This time, however, I received a JIT model from someone else and loaded it with network = torch.jit.load("model.pth"). Now FlopCountAnalysis crashes with the following error message:

  File "/lib/python3.8/site-packages/fvcore/nn/", line 248, in total
    stats = self._analyze()
  File "/lib/python3.8/site-packages/fvcore/nn/", line 551, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
  File "/lib/python3.8/site-packages/fvcore/nn/", line 174, in _get_scoped_trace_graph
    register_hooks(mod, name)
  File "/lib/python3.8/site-packages/fvcore/nn/", line 157, in register_hooks
    prehook = mod.register_forward_pre_hook(ScopePushHook(name))
  File "/lib/python3.8/site-packages/torch/jit/", line 942, in fail
    raise RuntimeError(name + " is not supported on ScriptModules")
RuntimeError: register_forward_pre_hook is not supported on ScriptModules

I suspect this is caused by the difference between a regular PyTorch model and a JIT-loaded model. Could someone suggest a way to evaluate the FLOPS of this model? Thanks.
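
A minimal repro of that difference, assuming a recent PyTorch install. TinyNet and try_pre_hook are hypothetical names used only for illustration, and the script/save/load round trip through an in-memory buffer simulates receiving model.pth from someone else. torch.jit.load returns a RecursiveScriptModule, which rejects the same hook registration that fvcore attempts.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for the real model; any scriptable nn.Module works.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return torch.relu(self.fc(x))


def try_pre_hook(module):
    """Return None if a forward pre-hook can be registered, else the error text."""
    try:
        handle = module.register_forward_pre_hook(lambda mod, inputs: None)
        handle.remove()  # clean up; we only wanted to know whether it works
        return None
    except RuntimeError as err:
        return str(err)


eager = TinyNet()

# Simulate a model shipped as a JIT file: script it, save it, load it back.
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(eager), buffer)
buffer.seek(0)
network = torch.jit.load(buffer)

print(type(network).__name__)  # RecursiveScriptModule
print(try_pre_hook(eager))     # None: eager modules accept hooks
print(try_pre_hook(network))   # the "not supported on ScriptModules" error
```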

AFAIK this is not straightforward. fvcore works by registering hooks on each nn.Module (the register_forward_pre_hook call in your traceback) so it can observe what each module is doing while it counts FLOPS, and ScriptModules do not support those hooks. torch.jit is in maintenance mode, so even if supporting hooks were possible, it is unlikely to be prioritized, and converting a JIT model back into a regular PyTorch model is not trivial. Your best bet is to profile the model without JIT, i.e. on the original eager-mode model.
