There is an example on this page of inserting a new `call_function` node into a `torch.fx.Graph`:
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))
    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)
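For reference, here is a self-contained version of that snippet. The toy module `M` with a single `torch.neg` is my own minimal scaffolding, not from the page above. Note that I restore the new node's argument after the replacement, since `replace_all_uses_with` also rewrites the use of `node` inside the freshly inserted `relu` node:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class M(nn.Module):
    def forward(self, x):
        return torch.neg(x)


traced = symbolic_trace(M())

for node in traced.graph.nodes:
    if node.op == "call_function" and node.target is torch.neg:
        with traced.graph.inserting_after(node):
            # Insert a new `call_function` node calling `torch.relu`
            new_node = traced.graph.call_function(torch.relu, args=(node,))
        # Redirect every user of `node` to the new `relu` node...
        node.replace_all_uses_with(new_node)
        # ...then restore `node` as the input of the `relu` itself,
        # since the line above also rewrote `new_node.args`
        new_node.args = (node,)

traced.graph.lint()
traced.recompile()
```

After recompiling, `traced(x)` computes `torch.relu(torch.neg(x))`.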
The callable here is torch.relu, which can be referenced explicitly in the script above. However, in many cases the call_function target is a Python built-in that has no obvious name to reference. In the following example, I'd like to gather specific channels from a 4-D tensor:
class GatherExample(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.channel_idxs = [0, 2, 3, 5, 10]

    def forward(self, x):
        out = x[:, self.channel_idxs]
        return out
Tracing this module and printing the graph gives:
opcode name target args kwargs
------------- ------- --------------------------- ------------------------------------------------ --------
placeholder x x () {}
call_function getitem <built-in function getitem> (x, (slice(None, None, None), [0, 2, 3, 5, 10])) {}
output output output (getitem,) {}
I can get an instance of the built-in callable indirectly by tracing this small module and reading the node's target, but that feels messy. Is there a better way to insert or delete these built-in call_function nodes in torch.fx?
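For what it's worth, the `<built-in function getitem>` in that table is, as far as I can tell, just `operator.getitem` from Python's standard library, so it seems possible to build such a node directly without tracing a throwaway module. A minimal sketch (the second gather of channel 0 is only for illustration):

```python
import operator

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class GatherExample(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.channel_idxs = [0, 2, 3, 5, 10]

    def forward(self, x):
        return x[:, self.channel_idxs]


traced = symbolic_trace(GatherExample())

# The traced indexing node's target is operator.getitem itself
getitem_node = next(n for n in traced.graph.nodes if n.op == "call_function")
assert getitem_node.target is operator.getitem

# So an equivalent node can be built directly: append a second
# gather that keeps only channel 0 of the already-gathered tensor
with traced.graph.inserting_after(getitem_node):
    new_node = traced.graph.call_function(
        operator.getitem,
        args=(getitem_node, (slice(None), [0])),
    )
getitem_node.replace_all_uses_with(new_node)
# replace_all_uses_with also rewrote new_node's own argument,
# so restore its intended input
new_node.args = (getitem_node, (slice(None), [0]))

traced.graph.lint()
traced.recompile()
```

After recompiling, `traced(x)` computes `x[:, [0, 2, 3, 5, 10]][:, [0]]`, i.e. only channel 0 survives.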