Convert PyTorch model to PTL

I have saved the PyTorch model after training using torch.save(..., model_path).

My final goal is to deploy the model on mobile. I am using below code for the purpose

model = Net()
model.load_state_dict(torch.load(model_path, map_location='cpu'))
traced_script_module = torch.jit.script(model)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)

With this code, I am getting the error below. Is there any fix for this problem?

File "", line 103, in convert_pth_to_ptl
traced_script_module = torch.jit.script(model)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 1258, in script
obj, torch.jit._recursive.infer_methods_to_compile
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 451, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 513, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 587, in _construct
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 491, in init_fn
scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 517, in create_script_module_impl
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 368, in create_methods_and_properties_from_stubs
concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 838, in try_compile_fn
return torch.jit.script(fn, _rcb=rcb)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 1307, in script
ast = get_jit_def(obj,
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 264, in get_jit_def
return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 302, in build_def
param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
File "/home/harshavardhana/ProjectTools/Python_VEnv/venv_pytorch/lib/python3.6/site-packages/torch/jit/", line 337, in build_param_list
raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "/usr/lib/python3.6/collections/", line 357
def namedtuple(typename, field_names, *, verbose=False, rename=False, module=None):
~~~~~ <--- HERE
"""Returns a new subclass of tuple with named fields.

You might be running into this issue, which occurs if the model uses a variable number of arguments, e.g. via def forward(self, *inputs, **kwargs).
Could you check your model definition and see where this might be the case?
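A quick way to run this check is a signature inspection using only the standard library. The Net class below is a hypothetical stand-in for a model; the same check works on any nn.Module's forward method:

```python
import inspect

class Net:
    # Hypothetical stand-in for a model class whose forward()
    # takes variable arguments.
    def forward(self, *inputs, **kwargs):
        return inputs

def uses_varargs(fn):
    # True if the callable takes *args or **kwargs, which
    # torch.jit.script cannot compile.
    params = inspect.signature(fn).parameters.values()
    return any(p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD) for p in params)

print(uses_varargs(Net.forward))  # -> True
```

On a real model you would loop over model.named_modules() and run this check on each submodule's forward to find the offender.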

I am trying to convert the pretrained model from GitHub - clovaai/CRAFT-pytorch: Official implementation of Character Region Awareness for Text Detection (CRAFT).
The code doesn't seem to use a variable number of arguments directly, but it uses VGG 16 as the basenet.
I am not sure about VGG 16.
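For what it's worth, the traceback above ends inside collections.namedtuple, which suggests the scripted code creates a namedtuple type at call time; TorchScript then tries to compile the namedtuple factory function itself, whose keyword-only defaults trigger exactly this NotSupportedError. A sketch of that pattern and a workaround follows (the class and field names here are illustrative, not taken from the CRAFT source):

```python
from collections import namedtuple

# Problematic pattern (illustrative): building the namedtuple type inside
# forward() makes torch.jit.script attempt to compile namedtuple() itself.
#
#     def forward(self, x):
#         Outputs = namedtuple("Outputs", ["features", "logits"])  # fails to script
#         return Outputs(x, x)
#
# Workaround: define the type once at module level, so forward() only
# instantiates it.
Outputs = namedtuple("Outputs", ["features", "logits"])

def forward_sketch(x):
    # Stands in for a model's forward(); it only instantiates the
    # pre-built type, which TorchScript can handle.
    return Outputs(features=x, logits=x)

print(forward_sketch(3))  # -> Outputs(features=3, logits=3)
```

If the basenet does this, moving the namedtuple definition out of forward (or using a typing.NamedTuple class) may get past this particular error.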

I wrote sample code to check your solution, and it works.
But one thing still confuses me:
if I save the complete model object, I don't get any error related to a variable number of arguments.
The problem occurs only when I save the state_dict().

I intend to evaluate multiple PyTorch models on mobile, and most of the available models ship only a state_dict().

Is there any way I can fix this problem, optimize the models for mobile, and evaluate them?

Could you describe where you are saving the state_dict? In your current code it seems you are loading a state_dict and then exporting the model. Is this creating the issue, and does the export work fine if you do not load the state_dict?

I am trying to convert the pretrained model available at GitHub - clovaai/CRAFT-pytorch: Official implementation of Character Region Awareness for Text Detection (CRAFT).
Hence there is no saving step in my code.

I am able to load the model, but when I run torch.jit.script(model), I run into errors.
Since the model is pretrained, I don't know exactly which args were used for training.

I don’t quite understand this, as you seem to have the model definition if you are creating the model object here:

model = Net()

Check where Net is defined and make sure no variable arguments are used.

That was an example code. Let me post my exact code here.
You can refer craft related code @ GitHub - clovaai/CRAFT-pytorch: Official implementation of Character Region Awareness for Text Detection (CRAFT)

import torch
import craft
from collections import OrderedDict

def copy_state_dict(state_dict):
    # Strip the "module." prefix that nn.DataParallel adds to key names.
    if list(state_dict.keys())[0].startswith('module'):
        start_index = 1
    else:
        start_index = 0

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = '.'.join(k.split('.')[start_index:])
        new_state_dict[name] = v

    return new_state_dict

model = craft.CRAFT()
model.load_state_dict(copy_state_dict(torch.load('craft_mlt_25k.pth', map_location='cpu')))

traced_script_model = torch.jit.script(model)
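To sanity-check the key-renaming step in isolation, here is a standalone sketch of the same logic as copy_state_dict, run on toy values in place of real tensors:

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    # Same idea as copy_state_dict above: drop the leading "module."
    # that nn.DataParallel prepends to every parameter name.
    start = 1 if next(iter(state_dict)).startswith("module") else 0
    return OrderedDict(
        (".".join(k.split(".")[start:]), v) for k, v in state_dict.items()
    )

# Toy state_dict standing in for a checkpoint saved from a DataParallel model.
parallel_sd = OrderedDict([("module.conv1.weight", 1), ("module.conv1.bias", 2)])
print(list(strip_module_prefix(parallel_sd)))  # -> ['conv1.weight', 'conv1.bias']
```

A checkpoint saved without DataParallel has no "module." prefix, in which case the helper leaves the keys unchanged.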