Difference in _get_trace_graph in version 1.7 and get_trace_graph in version 1.1

Help!
I know that torch.jit.get_trace_graph(model, args) in PyTorch 1.3.0 was renamed to torch.jit._get_trace_graph(model, args) in PyTorch 1.7.0, but the two functions seem to behave differently.

So what should I do to get the same output as torch.jit.get_trace_graph(model, args) in PyTorch 1.3.0 when using torch.jit._get_trace_graph(model, args) in PyTorch 1.7.0?

Here are some of my attempts:

Because of the following warning in the docstring of _get_trace_graph() in torch/jit/_trace.py (lines 1114-1142):

    .. warning::
        This function is internal-only and should only be used by the ONNX
        exporter. If you are trying to get a graph through tracing, please go
        through the public API instead::

            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
            trace_graph = trace.graph

    Trace a function or model, returning a tuple consisting of the both the
    *trace* of an execution, as well as the original return value. If return_inputs,
    also returns the trace inputs as part of the tuple

    Tracing is guaranteed not to change the semantics of the function/module
    that is traced.

    Arguments:
        f (torch.nn.Module or function): the function or module
            to be traced.
        args (tuple or Tensor): the positional arguments to pass to the
            function/module to be traced.  A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        kwargs (dict): the keyword arguments to pass to the function/module
            to be traced.

    Example (trace a cell):

    .. testcode::

        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))

so I tried the following code:

trace, _ = torch.jit.trace(model, args).graph

It fails with the following error:

Traceback (most recent call last):
  File "tools/train.py", line 299, in <module>
    main()
  File "tools/train.py", line 124, in main
    writer_dict['writer'].add_graph_deprecated(sor_model, (dump_input, ))
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/tensorboardX/writer.py", line 825, in add_graph_deprecated
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, profile_with_cuda, **kwargs))
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/tensorboardX/pytorch_graph.py", line 381, in graph
    trace, _ = torch.jit.trace(model, args).graph
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 745, in trace
    _module_class,
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 931, in trace_module
    module = make_module(mod, _module_class, _compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 991, in __init__
    assert isinstance(orig, torch.nn.Module)
AssertionError
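The assertion at the bottom of this traceback fires in TracedModule.__init__, which make_module calls with a `submodule` argument. One plausible cause (an assumption, not confirmed by the traceback alone) is that some child registered on the model is not a torch.nn.Module. A minimal diagnostic sketch, assuming children are stored in each module's `_modules` dict as they are for torch.nn.Module; `find_non_module_children` is a hypothetical helper name:

```python
from typing import List, Tuple


def find_non_module_children(root, module_type) -> List[Tuple[str, type]]:
    """Recursively walk root._modules and report children that are not
    instances of module_type (pass torch.nn.Module in real use).

    Hypothetical helper: assumes children live in a `_modules` dict, as
    they do for torch.nn.Module. None entries are skipped because unset
    optional submodules are stored as None.
    """
    bad: List[Tuple[str, type]] = []

    def walk(obj, prefix: str) -> None:
        for name, child in getattr(obj, "_modules", {}).items():
            if child is None:
                continue
            path = f"{prefix}.{name}" if prefix else name
            if isinstance(child, module_type):
                walk(child, path)  # descend into real modules only
            else:
                bad.append((path, type(child)))

    walk(root, "")
    return bad
```

With PyTorch installed this would be called as `find_non_module_children(sor_model, torch.nn.Module)`; an empty list means the cause lies elsewhere.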

I then followed the advice from the thread Missing ‘get_trace_graph’ function in Pytorch1.7 and tried:

trace, _ = torch.jit._get_trace_graph(model, args)

This time it fails with:

Traceback (most recent call last):
  File "tools/train.py", line 299, in <module>
    main()
  File "tools/train.py", line 124, in main
    writer_dict['writer'].add_graph_deprecated(sor_model, (dump_input, ))
  File "/home/ZhuoweiXu/anaconda3/envs/HRNet/lib/python3.6/site-packages/tensorboardX/writer.py", line 825, in add_graph_deprecated
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, profile_with_cuda, **kwargs))
  File "/home/ZhuoweiXu/anaconda3/envs/HRNet/lib/python3.6/site-packages/tensorboardX/pytorch_graph.py", line 387, in graph
    graph = trace.graph
AttributeError: 'torch._C.Graph' object has no attribute 'graph'

So I compared the implementations of _get_trace_graph(model, args) and get_trace_graph(model, args):

% PyTorch 1.7.0
From torch/jit/_trace.py, _get_trace_graph(), line 1111:

def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
                     return_inputs=False, _return_inputs_states=False):
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
    return outs

In class ONNXTracedModule (torch/jit/_trace.py, line 73), I found the following at line 125:

graph, out = torch._C._create_graph_by_tracing(
    wrapper,
    in_vars + module_state,
    _create_interpreter_name_lookup_fn(),
    self.strict,
    self._force_outplace,
)

and the variable graph appears to be a torch._C.Graph object (the type named in the AttributeError above), with the following attributes:

['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_export_onnx', '_pretty_print_onnx', 'addInput', 'appendNode', 'at', 'constant', 'copy', 'create', 'createClone', 'createCudaFusionGroup', 'createFusionGroup', 'dump_alias_db', 'eraseInput', 'findAllNodes', 'findNode', 'inputs', 'insertConstant', 'insertNode', 'lint', 'nodes', 'op', 'outputs', 'param_node', 'prependNode', 'registerOutput', 'return_node', 'str']

% PyTorch 1.1.0
From torch/jit/__init__.py, get_trace_graph(), line 171:

def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False):
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
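Apart from the extra parameters, both versions normalize `args` the same way and are called the same way; the function name is what changed. A minimal sketch of a version-tolerant lookup, assuming only that one of the two names exists on torch.jit; `resolve_trace_graph_fn` is a hypothetical helper, and _get_trace_graph remains internal-only per the warning quoted above:

```python
def resolve_trace_graph_fn(jit_module):
    """Return whichever trace-graph function jit_module provides.

    Hypothetical helper: prefers the old public name (get_trace_graph,
    as in the 1.1/1.3 releases discussed here) and falls back to the
    internal 1.7 name (_get_trace_graph).
    """
    for name in ("get_trace_graph", "_get_trace_graph"):
        fn = getattr(jit_module, name, None)
        if fn is not None:
            return fn
    raise AttributeError(
        "neither get_trace_graph nor _get_trace_graph found on "
        f"{getattr(jit_module, '__name__', jit_module)!r}"
    )


# Usage with PyTorch (not executed here):
#   import torch
#   first, out = resolve_trace_graph_fn(torch.jit)(model, args)
# `first` is still a different type on old and new versions,
# so callers must handle both.
```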

In class LegacyTracedModule (torch/jit/__init__.py, line 233), I found the following at line 247:

trace, all_trace_inputs = torch._C._tracer_enter(*(in_vars + module_state))

and the variable trace prints as <TracingState 0x5648e988f880>, with the following attributes:

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'graph', 'pop_scope', 'push_scope', 'set_graph']
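The two attribute lists show the practical difference: the old TracingState has a `graph` member, while the 1.7 torch._C.Graph has no `graph` attribute because it already is the graph, which is exactly what the AttributeError above reports. A minimal duck-typing sketch that accepts either object; `as_graph` is a hypothetical name, and treating TracingState.graph as a method that must be called is an assumption to check against a 1.1 install:

```python
def as_graph(trace):
    """Return a graph from the first value returned by a trace-graph call.

    - 1.7 _get_trace_graph: the value is already a torch._C.Graph, which
      has no `graph` attribute, so it is returned unchanged.
    - Old get_trace_graph: the value is a TracingState with a `graph`
      member; if that member is callable (assumed to be a method), it is
      called, otherwise it is returned as-is.
    """
    member = getattr(trace, "graph", None)
    if member is None:
        return trace
    return member() if callable(member) else member
```

In the tensorboardX graph() function from the traceback, `graph = trace.graph` could then become `graph = as_graph(trace)`; whether the rest of that function handles a 1.7 graph correctly is not verified here.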

So what should I do to make these two outputs the same?

I would really appreciate any advice!