Hi,
First, I know there is an issue with nested dictionaries when using “example_inputs”, but I was hoping the same is not true for “example_kwarg_inputs”.
Below are three toy examples: one that works (Example 1) and two that reproduce the error (Examples 2 and 3).
Bottom line: how can I trace a function whose arguments are nested dictionaries using “example_kwarg_inputs”?
Example 1 (works)
Code
import torch

def func(a, b):
    return a + b


x = {"a": torch.tensor(1), "b": torch.tensor(2)}
ei_traced = torch.jit.trace(func, example_inputs=(x["a"], x["b"]))
print(ei_traced.graph)
# Works: produces the same graph as the example_inputs trace above.
eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
print(eki_traced.graph)
Output
graph(%a : Long(requires_grad=0, device=cpu),
%b : Long(requires_grad=0, device=cpu)):
%2 : int = prim::Constant[value=1]() # <workdir>/test.py:4:0
%3 : Long(requires_grad=0, device=cpu) = aten::add(%a, %b, %2) # <workdir>/test.py:4:0
return (%3)
graph(%a : Long(requires_grad=0, device=cpu),
%b : Long(requires_grad=0, device=cpu)):
%2 : int = prim::Constant[value=1]() # <workdir>/test.py:4:0
%3 : Long(requires_grad=0, device=cpu) = aten::add(%a, %b, %2) # <workdir>/test.py:4:0
return (%3)
Example 2 (Error)
Code
import torch

def func(a, b):
    return a["a2"] + b


x = {"a": {"a2": torch.tensor(1)}, "b": torch.tensor(2)}
ei_traced = torch.jit.trace(func, example_inputs=(x["a"], x["b"]))
print(ei_traced.graph)
# Fails: traced declaration is func(Tensor a) -> Tensor (see output below).
eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
print(eki_traced.graph)
Output
graph(%a : Dict(str, Tensor),
%b : Long(requires_grad=0, device=cpu)):
%1 : str = prim::Constant[value="a2"]()
%2 : Tensor = aten::__getitem__(%a, %1) # <conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py:859:0
%4 : int = prim::Constant[value=1]() # <workdir>/test.py:4:0
%5 : Long(requires_grad=0, device=cpu) = aten::add(%2, %b, %4) # <workdir>/test.py:4:0
return (%5)
Traceback (most recent call last):
File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 459, in run_mod_and_filter_tensor_outputs
outs = wrap_retval(mod(**inputs))
RuntimeError: func() expected at most 1 argument(s) but received 2 argument(s). Declaration: func(Tensor a) -> Tensor
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<workdir>/test.py", line 11, in <module>
eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 884, in trace
_check_trace(
File "<conda_env>/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 552, in _check_trace
traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 467, in run_mod_and_filter_tensor_outputs
raise TracingCheckError(
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
encountered an exception while running the trace with test inputs.
Exception:
func() expected at most 1 argument(s) but received 2 argument(s). Declaration: func(Tensor a) -> Tensor
Example 3 (Error)
Code
import torch

def func(a, b):
    return a["a2"] + b["b2"]


x = {"a": {"a2": torch.tensor(1)}, "b": {"b2": torch.tensor(2)}}
ei_traced = torch.jit.trace(func, example_inputs=(x["a"], x["b"]))
print(ei_traced.graph)
# Fails: traced declaration is func() -> Tensor (see output below).
eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
print(eki_traced.graph)
Output
graph(%a : Dict(str, Tensor),
%b : Dict(str, Tensor)):
%1 : str = prim::Constant[value="a2"]()
%2 : Tensor = aten::__getitem__(%a, %1) # <conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py:859:0
%4 : str = prim::Constant[value="b2"]()
%5 : Tensor = aten::__getitem__(%b, %4) # <conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py:859:0
%6 : int = prim::Constant[value=1]() # <workdir>/test.py:4:0
%7 : Long(requires_grad=0, device=cpu) = aten::add(%2, %5, %6) # <workdir>/test.py:4:0
return (%7)
Traceback (most recent call last):
File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 459, in run_mod_and_filter_tensor_outputs
outs = wrap_retval(mod(**inputs))
RuntimeError: func() expected at most 0 argument(s) but received 2 argument(s). Declaration: func() -> Tensor
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<workdir>/test.py", line 11, in <module>
eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 884, in trace
_check_trace(
File "<conda_env>/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 552, in _check_trace
traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 467, in run_mod_and_filter_tensor_outputs
raise TracingCheckError(
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
encountered an exception while running the trace with test inputs.
Exception:
func() expected at most 0 argument(s) but received 2 argument(s). Declaration: func() -> Tensor