Hi all, I’m working on exporting a model containing a function with 3 args indicating the size of output tensors. The model looks something like the following:
class Module(nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
return torch.full((x.item(), y.item(), z.item()), 1)
# export + to_edge_transform_and_lower, the exported program is called edge_program...
edge_program.to_executorch()
The .to_executorch()
call on the last line triggers the following error:
Traceback (most recent call last):
File "/home/<whoami>/workspace/convert_litert/pte_tryout.py", line 200, in <module>
executorch = multi_entry_points_edge_program.to_executorch()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/program/_program.py", line 93, in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/program/_program.py", line 1364, in to_executorch
new_gm_res = p(new_gm)
^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/passes/infra/pass_base.py", line 44, in __call__
res = self.call(graph_module)
^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 576, in call
result = self.call_submodule(graph_module, tuple(inputs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 662, in call_submodule
res = super().call_submodule(graph_module, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 539, in call_submodule
interpreter.run(*inputs_data)
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/interpreter.py", line 167, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 379, in run_node
return super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/interpreter.py", line 230, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 611, in call_function
return self.callback.call_operator(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/passes/spec_prop_pass.py", line 96, in call_operator
meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 991, in tree_map
return treespec.unflatten(map(func, *flat_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 830, in unflatten
leaves = list(leaves)
^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/passes/spec_prop_pass.py", line 23, in make_spec
return TensorSpec.from_tensor(x)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/tensor.py", line 173, in from_tensor
spec = cls(
^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/tensor.py", line 144, in __init__
self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/tensor.py", line 71, in dim_order_from_stride
if s == 0:
^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/__init__.py", line 740, in __bool__
return self.node.bool_()
^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 574, in bool_
return self.guard_bool("", 0)
^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 512, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u19*u20, 0) (unhinted: Eq(u19*u20, 0)). (Size-like symbols: u20, u19)
After browsing the Dynamic shapes manual and the Dealing with GuardOnDataDependentSymNode errors handbook, I realized that this error is caused by the data-dependent shape (x.item(), y.item(), z.item()),
which is problematic for ExecuTorch export because there are no value constraints on the resulting unbacked symbols.
So I guess one solution could be transforming the exported graph before the .to_executorch()
call to add constraints (e.g. via torch._check) to the nodes derived from the three args. But I don’t know whether I’m on the right track, or whether there is a better way of resolving this?