Question on CustomFuseGraph (graph_fuser): pytorch/tvm, pytorch/glow

@tom: Awesome. Thanks for taking the time to respond. I think that answers my question to a great extent. So, is the plan to upstream these changes to CustomFuseGraph? Also, torch/tvm’s custom fusion seems to ignore control flow ops for now, as seen in fusion_pass.cpp.

I was also experimenting a bit, and it looks like fusing the blocks of control flow ops is not currently handled in subgraph_utils. Is this understanding right?

Let’s take this graph as an example:

graph(%x.1 : Float(*)):
  %18 : Float(1) = prim::Constant[value={1}]()
  %16 : int[] = prim::Constant[value=[1]]()
  %1 : None = prim::Constant()
  %2 : int = prim::Constant[value=1]() # script_module.py:64:22
  %3 : int = prim::Constant[value=2]() # script_module.py:68:20
  %4 : int = prim::Constant[value=3]() # script_module.py:71:18
  %ret.1 : Double(*) = aten::zeros(%16, %1, %1, %1, %1) # script_module.py:64:10
  %7 : Bool(*) = aten::eq(%x.1, %2) # script_module.py:65:7
  %8 : bool = aten::Bool(%7) # script_module.py:65:7
  %ret : Tensor(*) = prim::If(%8) # script_module.py:65:4
    block0():
      %12 : Tensor = aten::add_(%ret.1, %18, %2) # script_module.py:67:8
      %ret.4 : Double(*) = aten::mul(%ret.1, %3) # script_module.py:68:14
      %ret.7 : Double(*) = aten::add(%ret.4, %18, %2) # script_module.py:69:14
      -> (%ret.7)
    block1():
      %ret.9 : Float(*) = aten::add(%x.1, %4, %2) # script_module.py:71:14
      -> (%ret.9)
  return (%ret)

Here, the first node inside block0() of the prim::If, aten::add_(), takes %ret.1 as one of its inputs, and %ret.1 is defined outside the prim::If. When this node is cloned during mergeNodeIntoSubgraph, the clone is not able to find the metadata for this input. I think this is because, when cloning, the value_map only contains the inputs in block scope and not those in graph scope. I am not familiar enough with PT graph manipulation to understand why only the block’s inputs are added to value_map in ir.cpp, and only the prim::If’s inputs are added to the value map in subgraph_utils. The error seems to arise because value_map(i) in ir.cpp returns NULL, since %ret.1 is not in the value_map.
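To make my guess concrete, here is a toy model in plain Python. It is not PyTorch’s IR or API: Value, Block, Node, clone_node and captured_values are hypothetical names made up for this sketch, and the real cloning logic in ir.cpp may well differ. It only illustrates the failure mode I suspect (a value map seeded with just the prim::If’s own inputs) and one possible workaround, assuming the values captured from the outer scope can be registered in the map before cloning.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(eq=False)  # eq=False keeps identity hashing, so values can be dict keys
class Value:
    name: str


@dataclass(eq=False)
class Block:
    nodes: List["Node"] = field(default_factory=list)
    returns: List[Value] = field(default_factory=list)


@dataclass(eq=False)
class Node:
    kind: str
    inputs: List[Value]
    outputs: List[Value]
    blocks: List[Block] = field(default_factory=list)


def clone_node(node: Node, value_map: Dict[Value, Value]) -> Node:
    """Clone a node and its nested blocks, remapping each used value via value_map."""
    def lookup(v: Value) -> Value:
        if v not in value_map:
            raise KeyError(f"%{v.name} is not in value_map")
        return value_map[v]

    new_inputs = [lookup(v) for v in node.inputs]
    new_blocks = []
    for block in node.blocks:
        new_block = Block()
        for inner in block.nodes:
            new_block.nodes.append(clone_node(inner, value_map))
        new_block.returns = [lookup(v) for v in block.returns]
        new_blocks.append(new_block)
    new_outputs = [Value(o.name) for o in node.outputs]
    value_map.update(zip(node.outputs, new_outputs))  # later nodes can now see these
    return Node(node.kind, new_inputs, new_outputs, new_blocks)


def captured_values(node: Node) -> List[Value]:
    """Values used inside node's blocks but defined outside of them."""
    captured: List[Value] = []

    def visit(block: Block, defined: set) -> None:
        defined = set(defined)  # copy, so sibling blocks do not share definitions
        for inner in block.nodes:
            for v in inner.inputs:
                if v not in defined and v not in captured:
                    captured.append(v)
            for nested in inner.blocks:
                visit(nested, defined)
            defined.update(inner.outputs)
        for v in block.returns:
            if v not in defined and v not in captured:
                captured.append(v)

    for block in node.blocks:
        visit(block, set())
    return captured


# Rebuild the prim::If from the graph above (constants are modelled as plain values).
x, ret1, cond = Value("x.1"), Value("ret.1"), Value("8")
c18, c2, c3, c4 = Value("18"), Value("2"), Value("3"), Value("4")
ret4, ret7, ret9 = Value("ret.4"), Value("ret.7"), Value("ret.9")
block0 = Block(
    nodes=[
        Node("aten::add_", [ret1, c18, c2], [Value("12")]),
        Node("aten::mul", [ret1, c3], [ret4]),
        Node("aten::add", [ret4, c18, c2], [ret7]),
    ],
    returns=[ret7],
)
block1 = Block(nodes=[Node("aten::add", [x, c4, c2], [ret9])], returns=[ret9])
if_node = Node("prim::If", [cond], [Value("ret")], [block0, block1])

# Attempt 1: the map only knows the prim::If's direct input, so %ret.1 is missing.
value_map = {cond: Value("8")}
try:
    clone_node(if_node, value_map)
    first_error = None
except KeyError as err:
    first_error = str(err)

# Attempt 2: register every captured outer value first, then clone.
captured = captured_values(if_node)
value_map = {val: Value(val.name) for val in [cond] + captured}
cloned = clone_node(if_node, value_map)

print(first_error)
print([val.name for val in captured])
```

In this toy version the first attempt fails on %ret.1, which matches what I see, and the second attempt goes through once the captured values are in the map. Whether the real pass should lift such values into subgraph inputs in the same way is exactly the part I am unsure about.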

Is there an example in the codebase where fusion of control flow ops is handled? Also, please correct me if my understanding is wrong.