How can I insert a node of built-in call_function type into fx.graph.Graph?

There is an example on this page that inserts a new node of call_function type into an fx.graph.Graph:

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
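    node.replace_all_uses_with(new_node)
    # new_node also uses node, so replace_all_uses_with rewrote its
    # argument to new_node itself; point it back at node.
    new_node.args = (node,)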
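For reference, here is a self-contained version of the snippet above. It is a minimal sketch: the toy module AddOne and the choice of inserting after its add node are hypothetical. Because new_node takes node as its input, replace_all_uses_with also rewrites new_node's own argument and creates a self-loop. The sketch points that argument back at node afterwards.

```python
import operator

import torch
import torch.fx as fx


class AddOne(torch.nn.Module):
    # Hypothetical toy module used only for this illustration.
    def forward(self, x):
        return x + 1


traced = fx.symbolic_trace(AddOne())

# `x + 1` is traced as a call_function node whose target is operator.add.
node = next(n for n in traced.graph.nodes
            if n.op == "call_function" and n.target is operator.add)

with traced.graph.inserting_after(node):
    new_node = traced.graph.call_function(torch.relu, args=(node,))

# Redirect every user of `node` (here only the output node) to `new_node`.
node.replace_all_uses_with(new_node)
# new_node is itself a user of node, so its argument was rewritten to
# new_node as well; restore it to avoid a self-loop.
new_node.args = (node,)

traced.graph.lint()  # structural sanity check of the edited graph
traced.recompile()   # regenerate forward() from the edited graph

result = traced(torch.tensor([-3.0, 0.5]))
print(result)  # relu(x + 1)
```

Calling recompile() matters. A GraphModule keeps using its previously generated forward() until it is recompiled.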

The callable here is torch.relu, which can be referenced explicitly in the script above. In many cases, however, call_function targets are built-in functions that cannot be referenced explicitly in the same way. In the following example, I'd like to gather specific channels from a 4-D tensor:

import torch.nn as nn

class GatherExample(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.channel_idxs = [0, 2, 3, 5, 10]

    def forward(self, x):
        # Select channels 0, 2, 3, 5 and 10 along dim 1
        out = x[:, self.channel_idxs]

        return out

The traced graph looks like this:
opcode         name     target                       args                                              kwargs
-------------  -------  ---------------------------  ------------------------------------------------  --------
placeholder    x        x                            ()                                                {}
call_function  getitem  <built-in function getitem>  (x, (slice(None, None, None), [0, 2, 3, 5, 10]))  {}
output         output   output                       (getitem,)                                        {}
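You can reproduce the table above by tracing the module and walking its nodes. This avoids print_tabular(), which needs the optional tabulate package. The minimal sketch below does that and then checks what the "<built-in function getitem>" target actually is. It is the plain Python function operator.getitem, so rewrites can refer to it directly.

```python
import operator

import torch
import torch.fx as fx
import torch.nn as nn


class GatherExample(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.channel_idxs = [0, 2, 3, 5, 10]

    def forward(self, x):
        return x[:, self.channel_idxs]


gm = fx.symbolic_trace(GatherExample())

# Print roughly the same columns as the table: opcode, name, target, args.
for n in gm.graph.nodes:
    print(n.op, n.name, n.target, n.args)

# The only call_function node is the indexing x[:, [...]].
getitem_node = next(n for n in gm.graph.nodes if n.op == "call_function")
print(getitem_node.target is operator.getitem)
```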

I can get hold of the built-in callable indirectly by tracing a small module definition such as this one, but that approach is still a bit messy. Is there a better way to insert or delete these built-in call_function nodes in torch.fx?

You should use call_module instead:

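from torch.fx import symbolic_trace

gm = symbolic_trace(model)
gm.add_module("gather_op", GatherExample())

with gm.graph.inserting_after(ori_node):
    # The submodule's input must be passed through args
    new_node = gm.graph.call_module("gather_op", args=(ori_node,))

gm.recompile()  # regenerate gm.forward from the edited graph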
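The call_module approach can also be written end to end. The sketch below makes a few assumptions. The toy model Net, the choice of ori_node (the output of its conv submodule) and the wiring that feeds the gathered tensor into the following ReLU are all hypothetical. The submodule's input is passed to call_module through args, and the module is recompiled after the graph changes.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class GatherExample(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.channel_idxs = [0, 2, 3, 5, 10]

    def forward(self, x):
        return x[:, self.channel_idxs]


class Net(nn.Module):
    # Hypothetical model: one conv layer producing 16 channels, then ReLU.
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


gm = fx.symbolic_trace(Net())
gm.add_module("gather_op", GatherExample())

# Hypothetical choice of ori_node: the output of the conv submodule.
ori_node = next(n for n in gm.graph.nodes
                if n.op == "call_module" and n.target == "conv")

with gm.graph.inserting_after(ori_node):
    # The input to the submodule is passed through args.
    new_node = gm.graph.call_module("gather_op", args=(ori_node,))

# Route former users of ori_node (the relu) through the gather instead,
# then restore new_node's own input, which was rewritten as well.
ori_node.replace_all_uses_with(new_node)
new_node.args = (ori_node,)

gm.graph.lint()
gm.recompile()

x = torch.randn(1, 3, 8, 8)
out = gm(x)
print(out.shape)  # torch.Size([1, 5, 8, 8])
```

Using a submodule keeps the channel list inside GatherExample. The trade-off is that the GraphModule gains an extra attribute (gather_op) that must exist before the call_module node is created.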

getitem can be imported from the operator module:

from operator import getitem

idx = (slice(None), [0, 2, 3, 5, 10])  # same index as x[:, [0, 2, 3, 5, 10]]
traced.graph.call_function(getitem, args=(node, idx))
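The operator.getitem route can be sketched end to end without any extra submodule. The toy model Net and the choice of its conv output as node are hypothetical. The index tuple (slice(None), [0, 2, 3, 5, 10]) mirrors the args shown in the traced table above. The second half covers the reverse direction: it deletes such a node by rerouting its users and then calling Graph.erase_node, which only succeeds once the node has no users left.

```python
from operator import getitem

import torch
import torch.fx as fx
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical model: one conv layer producing 16 channels, then ReLU.
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


gm = fx.symbolic_trace(Net())
node = next(n for n in gm.graph.nodes
            if n.op == "call_module" and n.target == "conv")

# Same index as x[:, [0, 2, 3, 5, 10]]: full slice on dim 0, channel list on dim 1.
idx = (slice(None), [0, 2, 3, 5, 10])

# --- Insert: gathered = node[idx], used in place of node ---
with gm.graph.inserting_after(node):
    new_node = gm.graph.call_function(getitem, args=(node, idx))
node.replace_all_uses_with(new_node)
new_node.args = (node, idx)  # undo the self-reference created above
gm.graph.lint()
gm.recompile()

x = torch.randn(1, 3, 8, 8)
out_gathered = gm(x)
print(out_gathered.shape)  # torch.Size([1, 5, 8, 8])

# --- Delete: remove the getitem node again ---
new_node.replace_all_uses_with(node)  # users read the conv output directly
gm.graph.erase_node(new_node)         # allowed now: new_node has no users
gm.graph.lint()
gm.recompile()

out_restored = gm(x)
print(out_restored.shape)  # torch.Size([1, 16, 8, 8])
```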