When debugging a model, it’s useful to add asserts and similar checks to catch bugs. However, these asserts can be slow, so in production it’s useful to remove them all. For ordinary Python code, running with `python -O` does that, because the interpreter skips `assert` statements in optimized mode. However, when a model is exported with `torch.jit.script`, the asserts remain in the compiled graph even with the optimization flag turned on.
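A quick way to confirm the plain-Python half of this behavior is to run a one-line failing assert in a child interpreter with and without `-O`. This is a minimal sketch; the helper name `assert_runs` is hypothetical, and `-E` is passed so that a `PYTHONOPTIMIZE` variable in the environment cannot skew the default run.

```python
import subprocess
import sys

# A one-line script whose only statement is a failing assert.
SNIPPET = "assert False, 'assert was executed'"


def assert_runs(optimize: bool) -> bool:
    """Return True if the assert in SNIPPET fired in a child interpreter."""
    cmd = [sys.executable, "-E"]  # -E: ignore PYTHON* environment variables
    if optimize:
        cmd.append("-O")  # -O: assert statements are not executed
    cmd += ["-c", SNIPPET]
    result = subprocess.run(cmd, capture_output=True, text=True)
    # A fired assert exits non-zero with an AssertionError traceback on stderr.
    return result.returncode != 0 and "AssertionError" in result.stderr


print("default:", assert_runs(optimize=False))  # default: True
print("with -O:", assert_runs(optimize=True))   # with -O: False
```

Scripted modules do not follow this rule, as the `DummyNet` example shows.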
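```python
import torch

class DummyNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(20, 20, 3)

    def forward(self, x):
        # Debug-only sanity check on the channel dimension.
        assert x.shape[1] == 20
        return self.conv(x)

def export():
    net = DummyNet()
    jit_net = torch.jit.script(net)  # compile the module to TorchScript
    print(jit_net.graph)             # dump the IR of forward()

export()
```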
This prints the following graph:
```
graph(%self : __torch__.DummyNet,
      %x.1 : Tensor):
  %25 : str = prim::Constant[value="AssertionError: "]()
  %4 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=20]()
  %3 : int[] = aten::size(%x.1) # <string>:7:9
  %5 : int = aten::__getitem__(%3, %4)
  %7 : bool = aten::eq(%5, %6)
   = prim::If(%7)
    block0():
      -> ()
    block1():
       = prim::RaiseException(%25)
      -> ()
  %13 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self)
  %15 : Tensor = prim::CallMethod[name="forward"](%13, %x.1)
  return (%15)
```
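Because the assert survives as a `prim::If` whose failing branch calls `prim::RaiseException`, one way to catch it before shipping is a text-level check on the printed IR (for example `str(jit_net.graph)`). This is a minimal sketch, not a TorchScript pass: the helper `count_raise_nodes` is hypothetical, it only detects raises and does not remove them, and it also counts raises coming from explicit `raise` statements, not just asserts.

```python
def count_raise_nodes(graph_text: str) -> int:
    """Count lines of a printed TorchScript graph that call prim::RaiseException."""
    return sum("prim::RaiseException(" in line for line in graph_text.splitlines())


# Excerpt of the graph above: the lowered form of `assert x.shape[1] == 20`.
GRAPH_EXCERPT = """
  %7 : bool = aten::eq(%5, %6)
   = prim::If(%7)
    block0():
      -> ()
    block1():
       = prim::RaiseException(%25)
      -> ()
"""

print(count_raise_nodes(GRAPH_EXCERPT))  # 1
```

Note that the printed graph covers only the top-level `forward`; submodules such as `conv` appear as `prim::CallMethod`, so raises inside them would not be counted by this check.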