Generating error checking and basic flow control

Say I want to use FX to generate a GraphModule that, as its first operation, runs an assertion on the shape of its input argument. It is not possible to torch.fx.wrap assert, since assert is a statement, not a function. Implementing the check as an if is also not supported, since it counts as control flow.
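For concreteness, here is what the control-flow limitation looks like (a minimal sketch; the exact error text may vary across versions):

import torch
import torch.fx as fx

def f_if(x):
    # The `if` forces bool() on a Proxy, which symbolic tracing cannot handle:
    if x.shape[0] != 5:
        raise AssertionError("invalid shape")
    return x.sum()

# fx.symbolic_trace(f_if) raises a TraceError along the lines of
# "symbolically traced variables cannot be used as inputs to control flow"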

I can implement an asserter function and torch.fx.wrap it, but that seems quite ugly and gives bad tracebacks:

In [18]: def asserter(b: bool):
    ...:     assert b
    ...: 

In [19]: def f(x):
    ...:     asserter(x.shape[0] == 5)
    ...:     return x.sum()
    ...: 

In [20]: fx.wrap(asserter)
Out[20]: <function __main__.asserter(b: bool)>

In [21]: g = fx.symbolic_trace(f)

In [22]: print(g.code)
import __main__
def forward(self, x):
    getattr_1 = x.shape
    getitem = getattr_1[0];  getattr_1 = None
    eq_1 = getitem == 5;  getitem = None
    asserter = __main__.asserter(eq_1);  eq_1 = None
    sum_1 = x.sum();  x = None
    return sum_1

In [23]: s = torch.jit.script(g)

In [24]: s(torch.ones(7))

Error: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<ipython-input-18-375f83c1af1b>", line 2, in forward
def asserter(b: bool):
    assert b
    ~~~~~~~~ <--- HERE
RuntimeError: AssertionError: 


In [25]: s(torch.ones(5))
Out[25]: tensor(5.)

In [26]: s.graph
Out[26]: 
graph(%self : __torch__.torch.fx.graph_module.___torch_mangle_0.f,
      %x.1 : Tensor):
  %17 : None = prim::Constant()
  %13 : Function = prim::Constant[name="asserter"]()
  %5 : int = prim::Constant[value=0]()
  %9 : int = prim::Constant[value=5]()
  %getattr_1.1 : int[] = aten::size(%x.1) # <string>:7:9
  %getitem.1 : int = aten::__getitem__(%getattr_1.1, %5)
  %eq_1.1 : bool = aten::eq(%getitem.1, %9)
  %asserter : None = prim::CallFunction(%13, %eq_1.1)
  %sum_1.1 : Tensor = aten::sum(%x.1, %17)
  return (%sum_1.1)

As you can see in the final output, wrapping the assertion also complicates the compiled TorchScript significantly; using assert directly generates code without a Python function call:

In [29]: def f2(x):
    ...:     assert x.shape[0] == 5
    ...:     return x.sum()
    ...: 

In [30]: s2 = torch.jit.script(f2)

In [31]: print(s2.graph)
graph(%x.1 : Tensor):
  %13 : None = prim::Constant()
  %24 : str = prim::Constant[value="AssertionError: "]()
  %3 : int = prim::Constant[value=0]() # <ipython-input-29-fc36d8843bb5>:2:19
  %5 : int = prim::Constant[value=5]() # <ipython-input-29-fc36d8843bb5>:2:25
  %2 : int[] = aten::size(%x.1) # <string>:7:9
  %4 : int = aten::__getitem__(%2, %3) # <ipython-input-29-fc36d8843bb5>:2:11
  %6 : bool = aten::eq(%4, %5) # <ipython-input-29-fc36d8843bb5>:2:11
   = prim::If(%6) # <ipython-input-29-fc36d8843bb5>:2:4
    block0():
      -> ()
    block1():
       = prim::RaiseException(%24) # <ipython-input-29-fc36d8843bb5>:2:4
      -> ()
  %14 : Tensor = aten::sum(%x.1, %13) # <ipython-input-29-fc36d8843bb5>:3:11
  return (%14)

Is there a better way to go about this?

Thanks!

(This is in some sense a subset of the discussion on Customizing generated code? - #3 by Linux-cpp-lisp)

torch._assert() is traceable; you should be able to use it instead of the builtin assert. Additionally, for modules where you may not want to (or cannot) modify the source to use torch._assert(), there is an AST Rewriter tracer in FX experimental, see here.
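Usage might look something like this (a sketch assuming the torch.fx.experimental.rewriter.RewritingTracer entry point; the exact location and API may differ between versions):

import torch
import torch.fx as fx
from torch.fx.experimental.rewriter import RewritingTracer

def f(x):
    assert x.shape[0] == 5  # rewritten to torch._assert by the tracer
    return x.sum()

tracer = RewritingTracer()
graph = tracer.trace(f)  # AST-rewrites the assert, then traces as usual
gm = fx.GraphModule(tracer.root, graph)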

Hi @jfix,

Interesting — the AST rewriter looks like a great addition!

torch._assert looks like a good interim solution. The only remaining issue is that it still makes the backtrace useless:

In [4]: def f(x):
   ...:     torch._assert(x.shape[0] == 5, "invalid shape")
   ...:     return x*2.
   ...: 
   ...: 

In [5]: s = torch.jit.script(f)

In [6]: s(torch.ones(4))
---------------------------------------------------------------------------
Error                                     Traceback (most recent call last)
<ipython-input-6-62f4f187df9a> in <module>
----> 1 s(torch.ones(4))

Error: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<string>", line 8, in f
  return out_size
def _assert(condition : bool, message : str):
  assert condition, message
  ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: AssertionError: invalid shape


In [7]: s(torch.ones(5))
Out[7]: tensor([2., 2., 2., 2., 2.])

This is a minor issue, though, and can be mitigated by using an assertion message that says where torch._assert was called.
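For example (a sketch; the message format is arbitrary):

import torch

def f(x):
    # Embedding the call site in the message keeps the otherwise-generic
    # TorchScript traceback actionable.
    torch._assert(x.shape[0] == 5, "f: expected x.shape[0] == 5")
    return x * 2.0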