I have a custom I3D model and want to convert it to TorchScript so that it can be used with DeepStream.
I tried `torch.jit.trace` first, using the code below, but got an error:
x = torch.ones((1, 3, 64, 224, 224)).cuda()
traced_script_module = torch.jit.trace(net, x)
where `net` is my custom I3D model. Tracing fails with:
Dictionary inputs to traced functions must have consistent type. Found Tensor and Dict[str, Tensor]
I also tried `torch.jit.script`, and got this traceback:
traced_script_module = torch.jit.script(net)
Traceback (most recent call last):
File "convert_i3d.py", line 50, in <module>
traced_script_module = torch.jit.script(net)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 1516, in script
return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 318, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 330, in create_script_module_impl
stubs = stubs_fn(nn_module)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 535, in infer_methods_to_compile
stubs.append(make_stub_from_method(nn_module, method))
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 49, in make_stub_from_method
return make_stub(func, method_name)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 34, in make_stub
ast = torch.jit.get_jit_def(func, name, self_name="RecursiveScriptModule")
File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 185, in get_jit_def
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 219, in build_def
build_stmts(ctx, body))
File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 126, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 126, in <listcomp>
stmts = [build_stmt(ctx, s) for s in stmts]
File "/usr/local/lib/python3.6/dist-packages/torch/jit/frontend.py", line 192, in __call__
raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
File "/usr/local/lib/python3.6/dist-packages/drishti/nnh/inferencers/i3d.py", line 390
end_points["logits_pre_reshape"] = out_logits
def squeeze_and_permute(tensor):
~~~ <--- HERE
# [N, 8k, 8, H, W] -> [N, 8, H, W, 8k]
tensor = tensor.permute([0, 2, 3, 4, 1])