Exporting onnx and jit model at the same time

Hi, I am trying to export a torch model to onnx and jit at the same time for a library (Icefall). Some parts of models have to be changed to export. I am seperating these parts with torch.jit.is_scripting() and torch.onnx.is_in_onnx_export(). I can export onnx model, but cannot export jit model. A simple example:

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        self.fc2 = nn.Linear(10, 2)

    def forward(self, x):
        if torch.jit.is_scripting():
            x = self.fc1(x)
    
        if torch.onnx.is_in_onnx_export():
            x = self.fc2(x)
    
        return x

net = Net()
sample = torch.randn(1, 10)

torch.onnx.export(
    net,
    (sample),
    "model.onnx",
    verbose=False,
    opset_version=13,
    input_names=["x"],
    output_names=["out"],
)

net = torch.jit.script(net)
net.save(str("model.pt"))

Onnx export is done successfully, but jit scripting raise error:

UnsupportedNodeError                      Traceback (most recent call last)
Cell In [9], line 1
----> 1 net = torch.jit.script(net)
      2 net.save(str("model.pt"))

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_script.py:1257, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1255 if isinstance(obj, torch.nn.Module):
   1256     obj = call_prepare_scriptable_func(obj)
-> 1257     return torch.jit._recursive.create_script_module(
   1258         obj, torch.jit._recursive.infer_methods_to_compile
   1259     )
   1261 if isinstance(obj, dict):
   1262     return create_script_dict(obj)

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_recursive.py:451, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    449 if not is_tracing:
    450     AttributeTypeIsSupportedChecker().check(nn_module)
--> 451 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_recursive.py:517, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    515 # Compile methods if necessary
    516 if concrete_type not in concrete_type_store.methods_compiled:
--> 517     create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    518     # Create hooks after methods to ensure no name collisions between hooks and methods.
    519     # If done before, hooks can overshadow methods that aren't exported.
    520     create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_recursive.py:368, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    365 property_defs = [p.def_ for p in property_stubs]
    366 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 368 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_recursive.py:838, in try_compile_fn(fn, loc)
    834 # We don't have the actual scope where the function was defined, but we can
    835 # extract the necessary info from the closed over variables on the function
    836 # object
    837 rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
--> 838 return torch.jit.script(fn, _rcb=rcb)

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/_script.py:1307, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1305 if maybe_already_compiled_fn:
   1306     return maybe_already_compiled_fn
-> 1307 ast = get_jit_def(obj, obj.__name__)
   1308 if _rcb is None:
   1309     _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:264, in get_jit_def(fn, def_name, self_name, is_classmethod)
    261     qualname = get_qualified_name(fn)
    262     pdt_arg_types = type_trace_db.get_args_types(qualname)
--> 264 return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:315, in build_def(ctx, py_def, type_line, def_name, self_name, pdt_arg_types)
    310     type_comment_decl = torch._C.parse_type_comment(type_line)
    311     decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
    313 return Def(Ident(r, def_name),
    314            decl,
--> 315            build_stmts(ctx, body))

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:137, in build_stmts(ctx, stmts)
    136 def build_stmts(ctx, stmts):
--> 137     stmts = [build_stmt(ctx, s) for s in stmts]
    138     return list(filter(None, stmts))

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:137, in <listcomp>(.0)
    136 def build_stmts(ctx, stmts):
--> 137     stmts = [build_stmt(ctx, s) for s in stmts]
    138     return list(filter(None, stmts))

File ~/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/jit/frontend.py:286, in Builder.__call__(self, ctx, node)
    284 method = getattr(self, 'build_' + node.__class__.__name__, None)
    285 if method is None:
--> 286     raise UnsupportedNodeError(ctx, node)
    287 return method(ctx, node)

UnsupportedNodeError: import statements aren't supported:
  File "/home/yunusemre.ozkose/anaconda3/envs/k2sh/lib/python3.8/site-packages/torch/onnx/__init__.py", line 386
    """

    from torch.onnx import utils
    ~~~~ <--- HERE
    return utils.is_in_onnx_export()

How can I solve this?

My env:

Versions of relevant libraries:
[pip3] k2==1.21.dev20221103+cuda11.3.torch1.10.0
[pip3] numpy==1.22.4
[pip3] torch==1.10.0
[pip3] torchaudio==0.10.0
[pip3] torchvision==0.11.0
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               11.3.1              h9edb442_10    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] k2                        1.21.dev20221103+cuda11.3.torch1.10.0          pypi_0    pypi
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0            py38h95df7f1_0    conda-forge
[conda] mkl_fft                   1.3.1            py38h8666266_1    conda-forge
[conda] mkl_random                1.2.2            py38h1abd341_0    conda-forge
[conda] numpy                     1.22.4                   pypi_0    pypi
[conda] pytorch                   1.10.0          py3.8_cuda11.3_cudnn8.2.0_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                0.10.0               py38_cu113    pytorch
[conda] torchvision               0.11.0               py38_cu113    pytorch