Is it possible to pass a dict as input to a traced model when the dict's values have different types?
For example, the following dummy example fails with an error:
```python
import torch as th

class A(th.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = th.nn.Linear(3, 2)

    def forward(self, data):  # data is a dict
        return self.lin(data["x"])

data = {
    "x": th.rand(2, 3),
    "y": {
        "z": th.rand(5, 3)  # "y" has a dict as its value, while "x" has a tensor
    }
}
a = A()
a(data)  # works well
th.jit.trace(a, data)  # error happens
```
The error message is:
```
RuntimeError: Tracer cannot infer type of ({'x': tensor([[0.1657, 0.9904, 0.5861],
        [0.2403, 0.2352, 0.6989]]), 'y': {'z': tensor([[0.1907, 0.8666, 0.1840],
        [0.6437, 0.2888, 0.0081],
        [0.6681, 0.7088, 0.3869],
        [0.4880, 0.8989, 0.3832],
        [0.0746, 0.7623, 0.2115]])}},)
:Dictionary inputs to traced functions must have consistent type. Found Tensor and Dict[str, Tensor]
```
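The error states that every value in a dict input must have the same type. One possible workaround, sketched here under that assumption and not as an official recipe, is to flatten the nested dict into a homogeneous `Dict[str, Tensor]` (e.g. key `"y.z"` for `data["y"]["z"]`) and trace a thin wrapper that rebuilds the nested structure before calling the original model. The helper `flatten`, the wrapper `FlatWrapper`, and the dotted-key naming are all hypothetical; the wrapper hard-codes the nesting of this particular example.

```python
import torch as th


class A(th.nn.Module):
    # Same model as in the question.
    def __init__(self):
        super().__init__()
        self.lin = th.nn.Linear(3, 2)

    def forward(self, data):  # data is a (nested) dict
        return self.lin(data["x"])


def flatten(d, prefix=""):
    # Hypothetical helper: turns {"y": {"z": t}} into {"y.z": t},
    # so every value in the resulting dict is a Tensor.
    flat = {}
    for key, value in d.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=name + "."))
        else:
            flat[name] = value
    return flat


class FlatWrapper(th.nn.Module):
    # Hypothetical wrapper: takes a flat Dict[str, Tensor] and rebuilds
    # the nested dict that A.forward expects (nesting is hard-coded here).
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, flat):
        nested = {"x": flat["x"], "y": {"z": flat["y.z"]}}
        return self.model(nested)


a = A()
data = {"x": th.rand(2, 3), "y": {"z": th.rand(5, 3)}}
flat = flatten(data)  # {"x": Tensor, "y.z": Tensor}

# The example input is passed as a 1-tuple containing the flat dict.
traced = th.jit.trace(FlatWrapper(a), (flat,))
print(traced(flat).shape)  # torch.Size([2, 2])
```

The traced module is then called with a flat dict rather than the original nested one, so callers would need to apply the same flattening before each call.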