I am following this tutorial (torch.fx — PyTorch 1.12 documentation) and trying to append a sigmoid call after a relu. The code given in the tutorial is:
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))
    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)
And I wrote the following code to implement this:
import torch
from torch import fx
import copy

# Sample module
class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

m = M()
gm = fx.symbolic_trace(m)

for node in gm.graph.nodes:
    if node.op == 'call_function':
        if node.target == torch.relu:
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(torch.sigmoid, args=(node,))
                node.replace_all_uses_with(new_node)

gm.recompile()
print(gm.code)
However, the recompiled code is actually wrong:
def forward(self, x):
    relu = torch.relu(x);  x = None
    sigmoid = torch.sigmoid(sigmoid)
    return sigmoid
The relu node in the args of the sigmoid node was also replaced by the sigmoid node, presumably because replace_all_uses_with rewrites every user of node, including the newly inserted node itself. I worked around this with copy.deepcopy:
import torch
from torch import fx
import copy

# Sample module
class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

m = M()
gm = fx.symbolic_trace(m)

for node in gm.graph.nodes:
    if node.op == 'call_function':
        if node.target == torch.relu:
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(torch.sigmoid, args=(copy.deepcopy(node),))
                node.replace_all_uses_with(new_node)

gm.recompile()
print(gm.code)
This gives correct code:
def forward(self, x):
    relu = torch.relu(x);  x = None
    sigmoid = torch.sigmoid(relu);  relu = None
    return sigmoid
So I wonder: what is the actually recommended way to append a node?