Error: torch.jit.trace with example_kwarg_inputs and nested dictionaries

Hi,

First, I know there is an existing issue with nested dictionaries when using “example_inputs”, but I was hoping the same does not apply to “example_kwarg_inputs”.

Below are three toy examples: one that works and two that reproduce the error.
Bottom line: how can I trace a function that takes a nested dictionary using “example_kwarg_inputs”?

Example 1 (works)

Code

import torch

def func(a, b):
    return a + b

x = {"a": torch.tensor(1), "b": torch.tensor(2)}

# Trace with positional example inputs.
ei_traced = torch.jit.trace(func, example_inputs=(x["a"], x["b"]))
print(ei_traced.graph)

# Trace with keyword example inputs (a flat dict of tensors) - also works.
eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
print(eki_traced.graph)

Output

graph(%a : Long(requires_grad=0, device=cpu),
      %b : Long(requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=1]() # <workdir>/test.py:4:0
  %3 : Long(requires_grad=0, device=cpu) = aten::add(%a, %b, %2) # <workdir>/test.py:4:0
  return (%3)

graph(%a : Long(requires_grad=0, device=cpu),
      %b : Long(requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=1]() # <workdir>/test.py:4:0
  %3 : Long(requires_grad=0, device=cpu) = aten::add(%a, %b, %2) # <workdir>/test.py:4:0
  return (%3)

Example 2 (error)

Code

import torch

def func(a, b):
    return a["a2"] + b

x = {"a": {"a2": torch.tensor(1)}, "b": torch.tensor(2)}

# Tracing with positional example inputs works (first graph below).
ei_traced = torch.jit.trace(func, example_inputs=(x["a"], x["b"]))
print(ei_traced.graph)

# Tracing with keyword example inputs fails the post-trace check (traceback below).
eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
print(eki_traced.graph)

Output

graph(%a : Dict(str, Tensor),
      %b : Long(requires_grad=0, device=cpu)):
  %1 : str = prim::Constant[value="a2"]()
  %2 : Tensor = aten::__getitem__(%a, %1) # <conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py:859:0
  %4 : int = prim::Constant[value=1]() # <workdir>/test.py:4:0
  %5 : Long(requires_grad=0, device=cpu) = aten::add(%2, %b, %4) # <workdir>/test.py:4:0
  return (%5)

Traceback (most recent call last):
  File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 459, in run_mod_and_filter_tensor_outputs
    outs = wrap_retval(mod(**inputs))
RuntimeError: func() expected at most 1 argument(s) but received 2 argument(s). Declaration: func(Tensor a) -> Tensor

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<workdir>/test.py", line 11, in <module>
    eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
  File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 884, in trace
    _check_trace(
  File "<conda_env>/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 552, in _check_trace
    traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
  File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 467, in run_mod_and_filter_tensor_outputs
    raise TracingCheckError(
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
encountered an exception while running the trace with test inputs.
Exception:
        func() expected at most 1 argument(s) but received 2 argument(s). Declaration: func(Tensor a) -> Tensor
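
As a side note on example 2: the graph produced with example_inputs looks correct, and the failure happens only in the post-trace consistency check (_check_trace re-running the traced function with the kwargs). Out of curiosity, one could disable that check (check_trace is a regular torch.jit.trace argument):

eki_traced = torch.jit.trace(func, example_kwarg_inputs=x, check_trace=False)

but the Declaration in the error, func(Tensor a) -> Tensor, suggests the traced signature itself has collapsed to a single argument, so I doubt the result would be callable with the original nested dict anyway.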

Example 3 (error)

Code

import torch

def func(a, b):
    return a["a2"] + b["b2"]

x = {"a": {"a2": torch.tensor(1)}, "b": {"b2": torch.tensor(2)}}

# Tracing with positional example inputs works (first graph below).
ei_traced = torch.jit.trace(func, example_inputs=(x["a"], x["b"]))
print(ei_traced.graph)

# Tracing with keyword example inputs fails the post-trace check (traceback below).
eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
print(eki_traced.graph)

Output

graph(%a : Dict(str, Tensor),
      %b : Dict(str, Tensor)):
  %1 : str = prim::Constant[value="a2"]()
  %2 : Tensor = aten::__getitem__(%a, %1) # <conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py:859:0
  %4 : str = prim::Constant[value="b2"]()
  %5 : Tensor = aten::__getitem__(%b, %4) # <conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py:859:0
  %6 : int = prim::Constant[value=1]() # <workdir>/test.py:4:0
  %7 : Long(requires_grad=0, device=cpu) = aten::add(%2, %5, %6) # <workdir>/test.py:4:0
  return (%7)

Traceback (most recent call last):
  File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 459, in run_mod_and_filter_tensor_outputs
    outs = wrap_retval(mod(**inputs))
RuntimeError: func() expected at most 0 argument(s) but received 2 argument(s). Declaration: func() -> Tensor

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<workdir>/test.py", line 11, in <module>
    eki_traced = torch.jit.trace(func, example_kwarg_inputs=x)
  File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 884, in trace
    _check_trace(
  File "<conda_env>/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 552, in _check_trace
    traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
  File "<conda_env>/lib/python3.8/site-packages/torch/jit/_trace.py", line 467, in run_mod_and_filter_tensor_outputs
    raise TracingCheckError(
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
encountered an exception while running the trace with test inputs.
Exception:
        func() expected at most 0 argument(s) but received 2 argument(s). Declaration: func() -> Tensor
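
In the meantime, the only workaround I can think of is to keep the nested dictionary away from the trace boundary: trace a small wrapper that takes flat tensor keyword arguments and rebuilds the nested dicts before calling the real function. A rough sketch based on example 3 (the wrapper and the flattened key names are just placeholders I made up, not anything from the PyTorch API):

import torch

def func(a, b):
    return a["a2"] + b["b2"]

# Hypothetical wrapper: take flat tensor kwargs and rebuild the nested dicts
# before calling the original function.
def flat_func(a_a2, b_b2):
    return func({"a2": a_a2}, {"b2": b_b2})

flat_kwargs = {"a_a2": torch.tensor(1), "b_b2": torch.tensor(2)}

# This traces like example 1, since the trace boundary only sees flat tensors.
flat_traced = torch.jit.trace(flat_func, example_kwarg_inputs=flat_kwargs)
print(flat_traced.graph)

This gives a graph with plain tensor inputs rather than the Dict(str, Tensor) inputs that the example_inputs trace produces, so it is not really what I am after, hence the question.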