Exporting a function with size values as args using ExecuTorch

Hi all, I’m working on exporting a model whose forward function takes three args that specify the size of the output tensor. The model looks something like the following:

import torch
import torch.nn as nn

class Module(nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        return torch.full((x.item(), y.item(), z.item()), 1)

# export + to_edge_transform_and_lower, the exported program is called edge_program...
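# A minimal sketch of the elided step above; the example inputs and the
# default lowering config are my assumptions, not the exact original code.
from executorch.exir import to_edge_transform_and_lower

exported = torch.export.export(Module(), (torch.tensor(2), torch.tensor(3), torch.tensor(4)))
edge_program = to_edge_transform_and_lower(exported)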

edge_program.to_executorch()

The .to_executorch() call on the last line triggers the following error:

Traceback (most recent call last):
  File "/home/<whoami>/workspace/convert_litert/pte_tryout.py", line 200, in <module>
    executorch = multi_entry_points_edge_program.to_executorch()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/program/_program.py", line 93, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/program/_program.py", line 1364, in to_executorch
    new_gm_res = p(new_gm)
                 ^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/passes/infra/pass_base.py", line 44, in __call__
    res = self.call(graph_module)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 576, in call
    result = self.call_submodule(graph_module, tuple(inputs))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 662, in call_submodule
    res = super().call_submodule(graph_module, inputs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 539, in call_submodule
    interpreter.run(*inputs_data)
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 379, in run_node
    return super().run_node(n)
           ^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/interpreter.py", line 230, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/pass_base.py", line 611, in call_function
    return self.callback.call_operator(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/passes/spec_prop_pass.py", line 96, in call_operator
    meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 991, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 830, in unflatten
    leaves = list(leaves)
             ^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/passes/spec_prop_pass.py", line 23, in make_spec
    return TensorSpec.from_tensor(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/tensor.py", line 173, in from_tensor
    spec = cls(
           ^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/tensor.py", line 144, in __init__
    self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/executorch/exir/tensor.py", line 71, in dim_order_from_stride
    if s == 0:
       ^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/__init__.py", line 740, in __bool__
    return self.node.bool_()
           ^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 574, in bool_
    return self.guard_bool("", 0)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 512, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "/disk1/python311-executorch/lib64/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u19*u20, 0) (unhinted: Eq(u19*u20, 0)).  (Size-like symbols: u20, u19)

After browsing the Dynamic shapes manual and the Dealing with GuardOnDataDependentSymNode errors handbook, I realized that this error is caused by the data-dependent shape (x.item(), y.item(), z.item()): ExecuTorch cannot export it because there are no value constraints on those symbols.

So I guess one solution could be to transform the exported graph before the .to_executorch() call and add constraints to the nodes derived from the three args. But I don’t know whether I’m on the right track. Is there a better way of resolving this?


I’ve found a solution to this error. The following data-dependent error message:

torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u19*u20, 0) (unhinted: Eq(u19*u20, 0)).  (Size-like symbols: u20, u19)

indicates that the export fails because the compiler cannot decide whether Eq(u19*u20, 0) holds, i.e. whether the product of u19 (the symbol created for x) and u20 (the symbol created for y) equals zero. I need to explicitly tell the compiler that this cannot happen.

Therefore, the solution is to add a torch._check call as a constraint hint in the Module:

class Module(nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        h, w, d = x.item(), y.item(), z.item()
        # Constraint hint: the product of the first two sizes is never zero,
        # which lets the compiler resolve the guard Eq(u19*u20, 0) to False.
        torch._check(h * w != 0)
        return torch.full((h, w, d), 1)

After adding this assertion, you may need to add corresponding torch._check calls for other value computations if further guards are triggered, as in the sketch below.
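One option is to constrain each value individually; the range analysis can then resolve product guards such as Eq(u19*u20, 0) on its own. This is a sketch based on the same module, and the strictly-positive bounds are my assumption about the intended value ranges:

import torch
import torch.nn as nn

class Module(nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        h, w, d = x.item(), y.item(), z.item()
        # Each value is declared strictly positive, so its range is [1, inf)
        # and any product of these values is provably non-zero.
        torch._check(h > 0)
        torch._check(w > 0)
        torch._check(d > 0)
        return torch.full((h, w, d), 1)

Note that a non-negativity hint alone (e.g. torch._check_is_size) would not resolve this particular guard, since it still allows a value of zero.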

An additional hint for identifying the binding between the symbols (e.g. u19) and the variables (e.g. x):

After setting

export TORCH_LOGS=dynamo

TorchDynamo will output additional debugging hints. These include a suggestion to set TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u19,u20". With this environment variable set, we can see the additional operations performed when these two symbols are created. If we have previously set guards in our code via torch._check, they will show up in the debug info.
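Concretely, both variables can be set before re-running the export script (pte_tryout.py is the script from my traceback above):

export TORCH_LOGS=dynamo
export TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u19,u20"
python pte_tryout.py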

For example, if we have the following in our Python code:

torch._check(x.item() <= 100)

there will be something like

I0211 16:51:43.327000 1024007 torch/fx/experimental/symbolic_shapes.py:6281] runtime_assert u28 <= 100 [guard added] (_ops.py:723 in __call__), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u28 <= 100"

in the debug info. From this we can confirm that the symbol bound to x is u28.