I was using from fvcore.nn import FlopCountAnalysis to calculate the FLOPs of a PyTorch model, and it worked very well. But this time I received a JIT model from someone else and loaded it with network = torch.jit.load("model.pth"). Now FlopCountAnalysis crashes with the following error:
File "/lib/python3.8/site-packages/fvcore/nn/jit_analysis.py", line 248, in total
    stats = self._analyze()
File "/lib/python3.8/site-packages/fvcore/nn/jit_analysis.py", line 551, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
File "/lib/python3.8/site-packages/fvcore/nn/jit_analysis.py", line 174, in _get_scoped_trace_graph
    register_hooks(mod, name)
File "/lib/python3.8/site-packages/fvcore/nn/jit_analysis.py", line 157, in register_hooks
    prehook = mod.register_forward_pre_hook(ScopePushHook(name))
File "/lib/python3.8/site-packages/torch/jit/_script.py", line 942, in fail
    raise RuntimeError(name + " is not supported on ScriptModules")
RuntimeError: register_forward_pre_hook is not supported on ScriptModules
I assume this is due to the difference between a regular PyTorch model and a JIT-loaded ScriptModule, since fvcore relies on forward pre-hooks that ScriptModules don't support. Could someone suggest a method to evaluate the FLOPs of such a model? Thanks.
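In case it helps, here is a minimal sketch of what I am doing. TinyNet, the file name, and the input shape are placeholders (I don't have the source of the real model); the forward pass itself works, it is only the fvcore call (shown as a comment at the end) that raises the error above:

```python
import torch
import torch.nn as nn

# TinyNet is a stand-in for the real model, whose source I don't have
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

# Script, save, and reload -- mirrors how I received the model
torch.jit.save(torch.jit.script(TinyNet()), "model.pth")
network = torch.jit.load("model.pth")

# The loaded object is a ScriptModule, not a plain nn.Module,
# and running it directly works fine
x = torch.randn(1, 3, 32, 32)
out = network(x)

# This is the call that crashes for me with
# "RuntimeError: register_forward_pre_hook is not supported on ScriptModules":
# from fvcore.nn import FlopCountAnalysis
# FlopCountAnalysis(network, x).total()
```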