Error when using jit.trace with dictionary input

Is it possible to use a dict as input to a traced model when the dict's values have different types?
For example, the following dummy example fails with an error:

```python
import torch as th


class A(th.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = th.nn.Linear(3, 2)

    def forward(self, data):  # data is a dict
        return self.lin(data["x"])


data = {
    "x": th.rand(2, 3),
    "y": {
        "z": th.rand(5, 3)  # "y" has a dict as value, but "x" has a tensor as value
    },
}

a = A()
a(data)  # works well

th.jit.trace(a, data)  # error happens
```

The error says:

```
RuntimeError: Tracer cannot infer type of ({'x': tensor([[0.1657, 0.9904, 0.5861],
        [0.2403, 0.2352, 0.6989]]), 'y': {'z': tensor([[0.1907, 0.8666, 0.1840],
        [0.6437, 0.2888, 0.0081],
        [0.6681, 0.7088, 0.3869],
        [0.4880, 0.8989, 0.3832],
        [0.0746, 0.7623, 0.2115]])}},)
:Dictionary inputs to traced functions must have consistent type. Found Tensor and Dict[str, Tensor]
```

Based on this issue and workaround, it seems dicts that mix tensor values and nested-dict values are not supported as traced inputs, so you might need to flatten the input before tracing and unflatten it inside the model.
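A minimal sketch of the flatten/unflatten idea in plain Python, assuming `.` never appears inside the original keys. `flatten_dict` and `unflatten_dict` are hypothetical helper names, not part of PyTorch. Flattening the example input gives `{"x": ..., "y.z": ...}`, whose values are all tensors, which should satisfy the consistent-type requirement from the error message.

```python
from typing import Any, Dict

# Assumption: the separator "." never appears inside the original keys.
SEP = "."


def flatten_dict(nested: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into one level with joined keys.

    {"x": a, "y": {"z": b}} -> {"x": a, "y.z": b}
    """
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        full_key = f"{prefix}{SEP}{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into sub-dicts so only leaf values (e.g. tensors) remain.
            flat.update(flatten_dict(value, full_key))
        else:
            flat[full_key] = value
    return flat


def unflatten_dict(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Inverse of flatten_dict: rebuild the nesting from the joined keys."""
    nested: Dict[str, Any] = {}
    for full_key, value in flat.items():
        parts = full_key.split(SEP)
        node = nested
        for part in parts[:-1]:
            # Create intermediate dicts on demand.
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested
```

One way to wire this in (not checked against every PyTorch version) is a small wrapper module whose `forward` takes the flat dict, calls `unflatten_dict` on it and passes the result to the original `A`; you would then call `th.jit.trace` on the wrapper with `flatten_dict(data)` as the example input. Note that an empty sub-dict disappears when flattening, so it is not restored afterwards.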
