How to get the Python bytecode that causes a graph break?

Hi everyone. In Dynamo, the graph is broken by unsupported Python operations, such as data-dependent if/else. I want to merge the resulting subgraphs into one. For the example below, the graph is broken into two subgraphs.

def forward(self, b, c):
    if not torch.nn.functional.silu(b[0][0]):
        return torch.mm(b, c)
    else:
        return torch.add(b, c)
First subgraph (everything up to the tensor condition):

opcode         name       target                             args          kwargs
-------------  ---------  ---------------------------------  ------------  --------
placeholder    l_b_       L_b_                               ()            {}
call_function  getitem    <built-in function getitem>        (l_b_, 0)     {}
call_function  getitem_1  <built-in function getitem>        (getitem, 0)  {}
call_function  silu       <function silu at 0x7f9c4beff920>  (getitem_1,)  {}
output         output     output                             ((silu,),)    {}

Second subgraph (the resumed code for the branch taken at runtime):

opcode         name    target                                                  args          kwargs
-------------  ------  ------------------------------------------------------  ------------  --------
placeholder    l_b_    L_b_                                                    ()            {}
placeholder    l_c_    L_c_                                                    ()            {}
call_function  add     <built-in method add of type object at 0x7f9cc57b0aa0>  (l_b_, l_c_)  {}
output         output  output                                                  ((add,),)     {}
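
For what it's worth, torch._dynamo.explain can reproduce this split and report why each break happened. A minimal sketch, assuming the PyTorch 2.1+ API where explain(fn)(*args) returns an ExplainOutput (its break_reasons carry a reason string and user stack, not the raw bytecode):

import torch

def forward(b, c):
    # `not` on a tensor is data-dependent control flow -> graph break
    if not torch.nn.functional.silu(b[0][0]):
        return torch.mm(b, c)
    else:
        return torch.add(b, c)

b = torch.randn(1024, 1024)
c = torch.randn(1024, 1024)

explanation = torch._dynamo.explain(forward)(b, c)
print(explanation.graph_break_count)  # expected: 1 for this example
for reason in explanation.break_reasons:
    print(reason)

Each break reason points at the offending source line, which at least narrows down where the break happens.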

I want to merge the two subgraphs into one. Is there a way to get the Python bytecode that causes the graph break, or is there a better way to merge the graphs? Thank you for your help!

You can diagnose graph breaks with torch._logging (see torch._logging — PyTorch main documentation):

TORCH_LOGS="graph_breaks" python your_code.py
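
If you prefer to configure this from Python, torch._logging.set_logs should be the programmatic equivalent. A sketch (assuming the bytecode artifact is available in your build; it dumps the original and Dynamo-generated bytecode, which is close to what was asked for):

import torch._logging

# Equivalent to TORCH_LOGS="graph_breaks,bytecode" on the command line:
# graph_breaks logs where and why Dynamo splits the graph,
# bytecode logs the original and the transformed Python bytecode.
torch._logging.set_logs(graph_breaks=True, bytecode=True)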

I have added torch._logging.set_logs(graph_breaks=True), but it doesn't seem to work: there is no log output on my command line.

import torch

torch._dynamo.reset()
torch._logging.set_logs(graph_breaks=True)

class TestModule(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = torch.nn.Conv2d(3, 255, (5, 5), 3, bias=False)

    def forward(self, b, c):
        # `not` on a tensor should trigger a graph break here
        if not torch.nn.functional.silu(b[0][0]):
            return torch.mm(b, c)
        else:
            return torch.add(b, c)

model = TestModule()
model_opt = torch.compile(model)
print(model_opt(torch.randn((1024, 1024)), torch.randn((1024, 1024))))
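
If the logging route stays silent, the standard-library dis module can still show the raw Python bytecode of the function Dynamo traces; the conditional jump compiled from the `if not ...` line is where the break occurs (opcode names and offsets vary across Python versions):

import dis

# Disassemble the original forward; the conditional jump
# (e.g. POP_JUMP_IF_TRUE) on the silu result is the instruction
# Dynamo cannot evaluate on a tensor, so it breaks the graph there.
dis.dis(TestModule.forward)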