Say I want to use FX to generate a GraphModule that should, as its first operation, run an assertion on the shape of its input argument. It is not possible to `torch.fx.wrap` `assert`, since `assert` is a statement, not a function. Implementing the check as an `if` is also not supported, since branching on a traced value counts as control flow.
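For concreteness, the control-flow limitation can be reproduced minimally (catching `fx.proxy.TraceError` is my reading of what `symbolic_trace` raises here):

```python
import torch
import torch.fx as fx

def f_with_if(x):
    # Branching on a value derived from the traced input is control flow:
    if x.shape[0] != 5:
        raise AssertionError("expected first dim of size 5")
    return x.sum()

try:
    fx.symbolic_trace(f_with_if)
except fx.proxy.TraceError as e:
    print(type(e).__name__, e)
```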
I can implement an `asserter` function and `torch.fx.wrap` it, but that seems quite ugly and gives bad tracebacks:
In [18]: def asserter(b: bool):
    ...:     assert b
    ...:

In [19]: def f(x):
    ...:     asserter(x.shape[0] == 5)
    ...:     return x.sum()
    ...:

In [20]: fx.wrap(asserter)
Out[20]: <function __main__.asserter(b: bool)>
In [21]: g = fx.symbolic_trace(f)
In [22]: print(g.code)
import __main__
def forward(self, x):
    getattr_1 = x.shape
    getitem = getattr_1[0]; getattr_1 = None
    eq_1 = getitem == 5; getitem = None
    asserter = __main__.asserter(eq_1); eq_1 = None
    sum_1 = x.sum(); x = None
    return sum_1
In [23]: s = torch.jit.script(g)
In [24]: s(torch.ones(7))
Error: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<ipython-input-18-375f83c1af1b>", line 2, in forward
    def asserter(b: bool):
        assert b
        ~~~~~~~~ <--- HERE
RuntimeError: AssertionError:
In [25]: s(torch.ones(5))
Out[25]: tensor(5.)
In [26]: s.graph
Out[26]:
graph(%self : __torch__.torch.fx.graph_module.___torch_mangle_0.f,
      %x.1 : Tensor):
  %17 : None = prim::Constant()
  %13 : Function = prim::Constant[name="asserter"]()
  %5 : int = prim::Constant[value=0]()
  %9 : int = prim::Constant[value=5]()
  %getattr_1.1 : int[] = aten::size(%x.1) # <string>:7:9
  %getitem.1 : int = aten::__getitem__(%getattr_1.1, %5)
  %eq_1.1 : bool = aten::eq(%getitem.1, %9)
  %asserter : None = prim::CallFunction(%13, %eq_1.1)
  %sum_1.1 : Tensor = aten::sum(%x.1, %17)
  return (%sum_1.1)
As you can see in the final output, it also complicates the compiled TorchScript significantly; using `assert` directly generates code without a Python function call:
In [29]: def f2(x):
    ...:     assert x.shape[0] == 5
    ...:     return x.sum()
    ...:
In [30]: s2 = torch.jit.script(f2)
In [31]: print(s2.graph)
graph(%x.1 : Tensor):
  %13 : None = prim::Constant()
  %24 : str = prim::Constant[value="AssertionError: "]()
  %3 : int = prim::Constant[value=0]() # <ipython-input-29-fc36d8843bb5>:2:19
  %5 : int = prim::Constant[value=5]() # <ipython-input-29-fc36d8843bb5>:2:25
  %2 : int[] = aten::size(%x.1) # <string>:7:9
  %4 : int = aten::__getitem__(%2, %3) # <ipython-input-29-fc36d8843bb5>:2:11
  %6 : bool = aten::eq(%4, %5) # <ipython-input-29-fc36d8843bb5>:2:11
   = prim::If(%6) # <ipython-input-29-fc36d8843bb5>:2:4
    block0():
      -> ()
    block1():
       = prim::RaiseException(%24) # <ipython-input-29-fc36d8843bb5>:2:4
      -> ()
  %14 : Tensor = aten::sum(%x.1, %13) # <ipython-input-29-fc36d8843bb5>:3:11
  return (%14)
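One helper that might sidestep the hand-rolled `asserter` is `torch._assert`, which the PyTorch docs describe as a symbolically-traceable wrapper around Python's `assert`; I have not checked whether scripting the result produces the inline `prim::If` form or still goes through a `CallFunction`:

```python
import torch
import torch.fx as fx

def f3(x):
    # torch._assert is documented as traceable, unlike the assert statement
    torch._assert(x.shape[0] == 5, "expected first dim of size 5")
    return x.sum()

g3 = fx.symbolic_trace(f3)
print(g3.code)  # the check shows up as a torch._assert call in the graph
```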
Is there a better way to go about this?
Thanks!
(This is in some sense a subset of the discussion on Customizing generated code? - #3 by Linux-cpp-lisp)
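For reference, the post-hoc alternative of splicing the check into an already-traced graph with the Graph insertion API looks roughly like this (a sketch, assuming `torch._assert` and `graph.inserting_after` behave as documented):

```python
import operator
import torch
import torch.fx as fx

def f(x):
    return x.sum()

g = fx.symbolic_trace(f)
graph = g.graph

# Find the input node and splice the shape check in right after it
placeholder = next(n for n in graph.nodes if n.op == "placeholder")
with graph.inserting_after(placeholder):
    dim0 = graph.call_method("size", (placeholder, 0))
with graph.inserting_after(dim0):
    cond = graph.call_function(operator.eq, (dim0, 5))
with graph.inserting_after(cond):
    graph.call_function(torch._assert, (cond, "expected first dim of size 5"))

graph.lint()
g.recompile()
```

After `recompile()`, calling `g` on a wrongly-shaped input should raise `AssertionError`, while a correctly-shaped input runs as before.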