Help !!!
I know that `torch.jit.get_trace_graph(model, args)` in PyTorch 1.3.0 was renamed to `torch.jit._get_trace_graph(model, args)` in PyTorch 1.7.0, but the two functions seem to behave differently.
So what should I do to get the same output as `torch.jit.get_trace_graph(model, args)` in PyTorch 1.3.0 while using `torch.jit._get_trace_graph(model, args)` in PyTorch 1.7.0?
Here are some of the things I tried.
Because of the following warning in the docstring of `_get_trace_graph()` in torch/jit/_trace.py (lines 1114-1142):

```
.. warning::
    This function is internal-only and should only be used by the ONNX
    exporter. If you are trying to get a graph through tracing, please go
    through the public API instead::

        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
        trace_graph = trace.graph

Trace a function or model, returning a tuple consisting of the both the
*trace* of an execution, as well as the original return value. If return_inputs,
also returns the trace inputs as part of the tuple

Tracing is guaranteed not to change the semantics of the function/module
that is traced.

Arguments:
    f (torch.nn.Module or function): the function or module
        to be traced.
    args (tuple or Tensor): the positional arguments to pass to the
        function/module to be traced. A non-tuple is assumed to
        be a single positional argument to be passed to the model.
    kwargs (dict): the keyword arguments to pass to the function/module
        to be traced.

Example (trace a cell):

.. testcode::

    trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
```
so I tried the following code:

```python
trace, _ = torch.jit.trace(model, args).graph
```
Then it raised:
```
Traceback (most recent call last):
  File "tools/train.py", line 299, in <module>
    main()
  File "tools/train.py", line 124, in main
    writer_dict['writer'].add_graph_deprecated(sor_model, (dump_input, ))
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/tensorboardX/writer.py", line 825, in add_graph_deprecated
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, profile_with_cuda, **kwargs))
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/tensorboardX/pytorch_graph.py", line 381, in graph
    trace, _ = torch.jit.trace(model, args).graph
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 745, in trace
    _module_class,
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 931, in trace_module
    module = make_module(mod, _module_class, _compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 1043, in __init__
    submodule, TracedModule, _compilation_unit=None
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 563, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ivc/anaconda3/envs/PYt/lib/python3.6/site-packages/torch/jit/_trace.py", line 991, in __init__
    assert isinstance(orig, torch.nn.Module)
AssertionError
```
I followed the advice from the thread "Missing ‘get_trace_graph’ function in Pytorch1.7" and tried:

```python
trace, _ = torch.jit._get_trace_graph(model, args)
```

and then it raised:
```
Traceback (most recent call last):
  File "tools/train.py", line 299, in <module>
    main()
  File "tools/train.py", line 124, in main
    writer_dict['writer'].add_graph_deprecated(sor_model, (dump_input, ))
  File "/home/ZhuoweiXu/anaconda3/envs/HRNet/lib/python3.6/site-packages/tensorboardX/writer.py", line 825, in add_graph_deprecated
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, profile_with_cuda, **kwargs))
  File "/home/ZhuoweiXu/anaconda3/envs/HRNet/lib/python3.6/site-packages/tensorboardX/pytorch_graph.py", line 387, in graph
    graph = trace.graph
AttributeError: 'torch._C.Graph' object has no attribute 'graph'
```
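Going by this `AttributeError`, in 1.7.0 the first value returned by `torch.jit._get_trace_graph` is already a `torch._C.Graph`, whereas the old `get_trace_graph` returned a `TracingState` that exposes the graph through `graph`. Below is a minimal, version-tolerant sketch, assuming that is the only difference that matters. `trace_graph_compat` is a hypothetical helper (not part of PyTorch or tensorboardX); it takes the `torch.jit` module as a parameter so it can be tried with stand-ins, and since I am not sure whether `TracingState.graph` is a method or a plain attribute in every release, it accepts both.

```python
def trace_graph_compat(jit, model, args):
    """Return (graph, output) using whichever trace helper `jit` provides.

    Hypothetical helper, not part of PyTorch or tensorboardX. `jit` is meant
    to be the `torch.jit` module (or a stand-in with the same attribute names).
    """
    if hasattr(jit, "_get_trace_graph"):
        # Newer releases (e.g. 1.7.0): internal helper; per the AttributeError,
        # the first returned item is already a torch._C.Graph.
        result = jit._get_trace_graph(model, args)
    else:
        # Older releases (e.g. 1.1.0 / 1.3.0): first item is a TracingState.
        result = jit.get_trace_graph(model, args)
    # Keep only (first, output); any extra items (e.g. traced inputs) are ignored.
    first, out = result[0], result[1]

    graph_attr = getattr(first, "graph", None)
    if graph_attr is None:
        # No `graph` attribute: assume `first` is already the graph.
        return first, out
    # TracingState (assumption): `graph` may be a method or a plain attribute.
    graph = graph_attr() if callable(graph_attr) else graph_attr
    return graph, out
```

In tensorboardX's `pytorch_graph.py`, the failing pair of lines would then collapse to `graph, _ = trace_graph_compat(torch.jit, model, args)`. I have not checked whether the rest of that function relies on other `TracingState` features such as `push_scope`/`pop_scope`.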
So I compared `_get_trace_graph(model, args)` with `get_trace_graph(model, args)`.
% Pytorch 1.7.0
From torch/jit/_trace.py, `_get_trace_graph()`, line 1111:

```python
def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
                     return_inputs=False, _return_inputs_states=False):
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
    return outs
```
From torch/jit/_trace.py, class `ONNXTracedModule` (line 73), I found the following at line 125:

```python
graph, out = torch._C._create_graph_by_tracing(
    wrapper,
    in_vars + module_state,
    _create_interpreter_name_lookup_fn(),
    self.strict,
    self._force_outplace,
)
```
and the variable `graph` appears to be a `torch._C.Graph` (the type named in the `AttributeError` above) rather than a tracing state; it has the following attributes:

```
['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_export_onnx', '_pretty_print_onnx', 'addInput', 'appendNode', 'at', 'constant', 'copy', 'create', 'createClone', 'createCudaFusionGroup', 'createFusionGroup', 'dump_alias_db', 'eraseInput', 'findAllNodes', 'findNode', 'inputs', 'insertConstant', 'insertNode', 'lint', 'nodes', 'op', 'outputs', 'param_node', 'prependNode', 'registerOutput', 'return_node', 'str']
```
% Pytorch 1.1.0
From torch/jit/__init__.py, `get_trace_graph()`, line 171:

```python
def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False):
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
```
From torch/jit/__init__.py, class `LegacyTracedModule` (line 233), I found the following at line 247:

```python
trace, all_trace_inputs = torch._C._tracer_enter(*(in_vars + module_state))
```
and the variable `trace` shows `<TracingState 0x5648e988f880>`, with the following attributes:

```
['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'graph', 'pop_scope', 'push_scope', 'set_graph']
```
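Treating the two `dir()` listings as sets makes the difference concrete. The sets below are just the non-dunder names copied from the listings above; nothing else is assumed. Every name on the 1.1.0 `TracingState` (`graph` plus the scope helpers) is missing from the 1.7.0 object, while the 1.7.0 object carries graph-inspection methods such as `nodes`, `inputs` and `outputs`, which is consistent with it being the graph itself rather than a wrapper around one.

```python
# Non-dunder names copied from the two dir() listings above.
graph_17 = {
    "_export_onnx", "_pretty_print_onnx", "addInput", "appendNode", "at",
    "constant", "copy", "create", "createClone", "createCudaFusionGroup",
    "createFusionGroup", "dump_alias_db", "eraseInput", "findAllNodes",
    "findNode", "inputs", "insertConstant", "insertNode", "lint", "nodes",
    "op", "outputs", "param_node", "prependNode", "registerOutput",
    "return_node", "str",
}
tracing_state_11 = {"graph", "pop_scope", "push_scope", "set_graph"}

# Names only the old TracingState has, and names both objects share.
only_tracing_state = sorted(tracing_state_11 - graph_17)
shared = sorted(tracing_state_11 & graph_17)
# Graph-inspection methods available on the 1.7.0 object.
graph_methods = sorted(graph_17 & {"nodes", "inputs", "outputs"})

print("only on TracingState:", only_tracing_state)  # ['graph', 'pop_scope', 'push_scope', 'set_graph']
print("shared:", shared)                              # []
print("graph methods on 1.7.0 object:", graph_methods)  # ['inputs', 'nodes', 'outputs']
```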
So what should I do to make these two outputs the same?
I would really appreciate any advice!