Hi, I am trying to export a torch model to onnx and jit at the same time for a library (Icefall). Some parts of models have to be changed to export. I am seperating these parts with torch.jit.is_scripting()
and torch.onnx.is_in_onnx_export()
. I can export onnx model, but cannot export jit model. A simple example:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 2)
self.fc2 = nn.Linear(10, 2)
def forward(self, x):
if torch.jit.is_scripting():
x = self.fc1(x)
if torch.onnx.is_in_onnx_export():
x = self.fc2(x)
return x
net = Net()
sample = torch.randn(1, 10)
torch.onnx.export(
net,
(sample),
"model.onnx",
verbose=False,
opset_version=13,
input_names=["x"],
output_names=["out"],
)
net = torch.jit.script(net)
net.save(str("model.pt"))
Onnx export is done successfully, but jit scripting raise error:
UnsupportedNodeError Traceback (most recent call last)
Cell In [9], line 1
----> 1 net = torch.jit.script(net)
2 net.save(str("model.pt"))
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_script.py:1257, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1255 if isinstance(obj, torch.nn.Module):
1256 obj = call_prepare_scriptable_func(obj)
-> 1257 return torch.jit._recursive.create_script_module(
1258 obj, torch.jit._recursive.infer_methods_to_compile
1259 )
1261 if isinstance(obj, dict):
1262 return create_script_dict(obj)
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_recursive.py:451, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
449 if not is_tracing:
450 AttributeTypeIsSupportedChecker().check(nn_module)
--> 451 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_recursive.py:517, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
515 # Compile methods if necessary
516 if concrete_type not in concrete_type_store.methods_compiled:
--> 517 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
518 # Create hooks after methods to ensure no name collisions between hooks and methods.
519 # If done before, hooks can overshadow methods that aren't exported.
520 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_recursive.py:368, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
365 property_defs = [p.def_ for p in property_stubs]
366 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 368 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_recursive.py:838, in try_compile_fn(fn, loc)
834 # We don't have the actual scope where the function was defined, but we can
835 # extract the necessary info from the closed over variables on the function
836 # object
837 rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
--> 838 return torch.jit.script(fn, _rcb=rcb)
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_script.py:1307, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1305 if maybe_already_compiled_fn:
1306 return maybe_already_compiled_fn
-> 1307 ast = get_jit_def(obj, obj.__name__)
1308 if _rcb is None:
1309 _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:264, in get_jit_def(fn, def_name, self_name, is_classmethod)
261 qualname = get_qualified_name(fn)
262 pdt_arg_types = type_trace_db.get_args_types(qualname)
--> 264 return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:315, in build_def(ctx, py_def, type_line, def_name, self_name, pdt_arg_types)
310 type_comment_decl = torch._C.parse_type_comment(type_line)
311 decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
313 return Def(Ident(r, def_name),
314 decl,
--> 315 build_stmts(ctx, body))
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:137, in build_stmts(ctx, stmts)
136 def build_stmts(ctx, stmts):
--> 137 stmts = [build_stmt(ctx, s) for s in stmts]
138 return list(filter(None, stmts))
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:137, in <listcomp>(.0)
136 def build_stmts(ctx, stmts):
--> 137 stmts = [build_stmt(ctx, s) for s in stmts]
138 return list(filter(None, stmts))
File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:286, in Builder.__call__(self, ctx, node)
284 method = getattr(self, 'build_' + node.__class__.__name__, None)
285 if method is None:
--> 286 raise UnsupportedNodeError(ctx, node)
287 return method(ctx, node)
UnsupportedNodeError: import statements aren't supported:
File "/home/yunusemre.ozkose/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/onnx/__init__.py", line 386
"""
from torch.onnx import utils
~~~~ <--- HERE
return utils.is_in_onnx_export()
How can I solve this?
My env:
Versions of relevant libraries:
[pip3] k2==1.21.dev20221103+cuda11.3.torch1.10.0
[pip3] numpy==1.22.4
[pip3] torch==1.10.0
[pip3] torchaudio==0.10.0
[pip3] torchvision==0.11.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h9edb442_10 conda-forge
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] k2 1.21.dev20221103+cuda11.3.torch1.10.0 pypi_0 pypi
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h95df7f1_0 conda-forge
[conda] mkl_fft 1.3.1 py38h8666266_1 conda-forge
[conda] mkl_random 1.2.2 py38h1abd341_0 conda-forge
[conda] numpy 1.22.4 pypi_0 pypi
[conda] pytorch 1.10.0 py3.8_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 0.10.0 py38_cu113 pytorch
[conda] torchvision 0.11.0 py38_cu113 pytorch