Difference between _get_trace_graph in PyTorch 1.7 and get_trace_graph in PyTorch 1.1

Help !!!
I know that torch.jit.get_trace_graph(model, args) in PyTorch 1.3.0 was renamed to torch.jit._get_trace_graph(model, args) in PyTorch 1.7.0, but the two functions seem to behave differently.

So what should I do to get the same output as torch.jit.get_trace_graph(model, args) in PyTorch 1.3.0 while using torch.jit._get_trace_graph(model, args) in PyTorch 1.7.0?

Here are the things I have tried so far.


Because of the following warning in the docstring of _get_trace_graph() in torch/jit/_trace.py (lines 1114-1142):

    warning::
        This function is internal-only and should only be used by the ONNX
        exporter. If you are trying to get a graph through tracing, please go
        through the public API instead::

            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
            trace_graph = trace.graph

    Trace a function or model, returning a tuple consisting of the both the
    *trace* of an execution, as well as the original return value. If return_inputs,
    also returns the trace inputs as part of the tuple

    Tracing is guaranteed not to change the semantics of the function/module
    that is traced.

    Arguments:
        f (torch.nn.Module or function): the function or module
            to be traced.
        args (tuple or Tensor): the positional arguments to pass to the
            function/module to be traced.  A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        kwargs (dict): the keyword arguments to pass to the function/module
            to be traced.

    Example (trace a cell):

    .. testcode::

        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
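For reference, the public route that the warning recommends can be sketched as follows. The toy nn.Sequential model and the input shape are assumptions that stand in for the real model. torch.jit.trace returns a traced ScriptModule, and its .graph is a single torch._C.Graph. It should therefore be assigned to one name, not unpacked into two.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model (assumption, not from the original post).
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
example_input = torch.randn(1, 3, 32, 32)

# Public API: returns a traced ScriptModule.
traced = torch.jit.trace(model, (example_input,))

# .graph is ONE torch._C.Graph object, not a (trace, output) tuple.
graph = traced.graph

# Print the kind of every top-level node in the traced graph.
for node in graph.nodes():
    print(node.kind())
```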

So I tried the following code:

trace, _ = torch.jit.trace(model, args).graph

and it failed with:

Traceback (most recent call last):
  File "tools/train.py", line 299, in <module>
    main()
  File "tools/train.py", line 124, in main
    writer_dict['writer'].add_graph_deprecated(sor_model, (dump_input, ))
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/tensorboardX/writer.py", line 825, in add_graph_deprecated
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, profile_with_cuda, **kwargs))
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/tensorboardX/pytorch_graph.py", line 381, in graph
    trace, _ = torch.jit.trace(model, args).graph
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 745, in trace
    _module_class,
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 931, in trace_module
    module = make_module(mod, _module_class, _compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 991, in __init__
    assert isinstance(orig, torch.nn.Module)
AssertionError
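The failing check is `assert isinstance(orig, torch.nn.Module)`, which the tracer reaches while it recursively wraps submodules. One possible cause is an entry in some module's `_modules` that is not an nn.Module, such as a None placeholder added with ModuleList.append(None). This is an assumption that the traceback does not confirm. The hypothetical helper below (not part of PyTorch) lists such entries so they can be inspected. Whether a None entry actually triggers the assertion may depend on the exact PyTorch version.

```python
import torch.nn as nn

def find_non_module_children(module, prefix=""):
    # Hypothetical diagnostic helper: walk module._modules recursively and
    # collect (qualified_name, type_name) for every entry that is not an
    # nn.Module instance, including None placeholders.
    suspects = []
    for name, child in module._modules.items():
        qualified = f"{prefix}.{name}" if prefix else name
        if isinstance(child, nn.Module):
            suspects.extend(find_non_module_children(child, qualified))
        else:
            suspects.append((qualified, type(child).__name__))
    return suspects

# Toy example: a ModuleList that holds a None placeholder.
model = nn.Sequential(nn.ModuleList([nn.Linear(4, 4), None]))
print(find_non_module_children(model))  # [('0.1', 'NoneType')]
```

Running it on the real model (for example `find_non_module_children(sor_model)`) would show which submodule path to look at first.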

Then, following the advice in the thread “Missing ‘get_trace_graph’ function in Pytorch1.7”, I tried:

trace, _ = torch.jit._get_trace_graph(model, args)

and then it returned:

Traceback (most recent call last):
  File "tools/train.py", line 299, in <module>
    main()
  File "tools/train.py", line 124, in main
    writer_dict['writer'].add_graph_deprecated(sor_model, (dump_input, ))
  File "/home/ZhuoweiXu/anaconda3/envs/HRNet/lib/python3.6/site-packages/tensorboardX/writer.py", line 825, in add_graph_deprecated
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, profile_with_cuda, **kwargs))
  File "/home/ZhuoweiXu/anaconda3/envs/HRNet/lib/python3.6/site-packages/tensorboardX/pytorch_graph.py", line 387, in graph
    graph = trace.graph
AttributeError: 'torch._C.Graph' object has no attribute 'graph'
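The AttributeError follows from the return value. In 1.7, torch.jit._get_trace_graph(model, args) returns a (graph, output) tuple whose first element is already a torch._C.Graph, while code written for the old API expects an object that exposes graph. The minimal check below uses a toy model as an assumption in place of the real one. Note that _get_trace_graph is an internal, underscore-prefixed function and may change between releases.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())  # toy stand-in (assumption)
x = torch.randn(1, 4)

# Internal API: returns (graph, output); graph is already a torch._C.Graph.
graph, out = torch.jit._get_trace_graph(model, (x,))

print(type(graph))
print(hasattr(graph, "graph"))  # False, hence `trace.graph` raises AttributeError
```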

So I compared the implementations of _get_trace_graph(model, args) and get_trace_graph(model, args).

% PyTorch 1.7.0
From torch/jit/_trace.py, _get_trace_graph(), line 1111:

def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
                     return_inputs=False, _return_inputs_states=False):
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
    return outs

In torch/jit/_trace.py, class ONNXTracedModule (line 73), I found the following at line 125:

graph, out = torch._C._create_graph_by_tracing(
    wrapper,
    in_vars + module_state,
    _create_interpreter_name_lookup_fn(),
    self.strict,
    self._force_outplace,
)

and the variable graph appears to be the graph itself (a torch._C.Graph, matching the AttributeError above) rather than a tracing state. It has the following attributes:

['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_export_onnx', '_pretty_print_onnx', 'addInput', 'appendNode', 'at', 'constant', 'copy', 'create', 'createClone', 'createCudaFusionGroup', 'createFusionGroup', 'dump_alias_db', 'eraseInput', 'findAllNodes', 'findNode', 'inputs', 'insertConstant', 'insertNode', 'lint', 'nodes', 'op', 'outputs', 'param_node', 'prependNode', 'registerOutput', 'return_node', 'str']

% PyTorch 1.1.0
From torch/jit/__init__.py, get_trace_graph(), line 171:

def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False):
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)

In torch/jit/__init__.py, class LegacyTracedModule (line 233), I found the following at line 247:

trace, all_trace_inputs = torch._C._tracer_enter(*(in_vars + module_state))

and the variable trace shows <TracingState 0x5648e988f880>, with the following attributes:

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'graph', 'pop_scope', 'push_scope', 'set_graph']
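The two listings suggest one way to get the same kind of object on both versions: unwrap the old TracingState, and use the new Graph as is. The sketch below assumes that the old API's `trace.graph()` returns the Graph, matching the `graph` attribute in the listing above and the `trace.graph()` usage of older code. `trace_to_graph` is a hypothetical name. Only the new-API branch can run on a current install.

```python
import torch
import torch.nn as nn

def trace_to_graph(model, args):
    # Hypothetical compatibility helper: always return a torch._C.Graph.
    if not isinstance(args, tuple):
        args = (args,)
    if hasattr(torch.jit, "get_trace_graph"):
        # Old API (assumption based on the 1.1 listing): the first element is
        # a TracingState, and its graph() method returns the Graph.
        trace, _ = torch.jit.get_trace_graph(model, args)
        return trace.graph()
    # 1.7-style internal API: the first element is already the Graph.
    graph, _ = torch.jit._get_trace_graph(model, args)
    return graph

g = trace_to_graph(nn.Linear(3, 1), torch.randn(2, 3))
print(g)
```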

So what should I do to make these two outputs the same?

I would really appreciate any advice!

AttributeError: ‘torch._C.Graph’ object has no attribute ‘graph’
This error means that the trace returned by torch.jit._get_trace_graph has no graph attribute, because in 1.7 it is already a torch._C.Graph. So when you want to access something like the nodes, change trace.graph().nodes() to trace.nodes().
In other words, just drop the graph() call and access the Graph's methods directly. The output trace is the graph itself, so you no longer need to go through graph() first as you did with the older version.
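Applied to the 1.7 API, the change looks roughly like this. The toy model is an assumption that stands in for the real one. The old 1.1-style call is kept as a comment for comparison.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU())  # toy stand-in (assumption)
x = torch.randn(1, 3, 8, 8)

trace, _ = torch.jit._get_trace_graph(model, (x,))

# Old (1.1-style): for node in trace.graph().nodes(): ...
# New (1.7-style): `trace` is already a torch._C.Graph, so call nodes() directly.
kinds = [node.kind() for node in trace.nodes()]
input_names = [value.debugName() for value in trace.inputs()]

print(kinds)
print(input_names)
```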