Replace node error in torch.fx tutorial

I am following the torch.fx tutorial (PyTorch 1.12 documentation) and trying to append a sigmoid function after relu.
The code given in the tutorial is:

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

I wrote the following code to do this:

import torch
from torch import fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

m = M()
gm = fx.symbolic_trace(m)
for node in gm.graph.nodes:
    if node.op == 'call_function':
        if node.target == torch.relu:
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(torch.sigmoid, args=(node,))
                node.replace_all_uses_with(new_node)

gm.recompile()
print(gm.code)

However, the recompiled code is wrong:

def forward(self, x):
    relu = torch.relu(x);  x = None
    sigmoid = torch.sigmoid(sigmoid)
    return sigmoid

The relu node in the args of the sigmoid node was also replaced by the sigmoid node, so sigmoid now takes itself as input. I worked around this with copy.deepcopy, using the following code:

import torch
from torch import fx
import copy

# Sample module
class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

m = M()
gm = fx.symbolic_trace(m)
for node in gm.graph.nodes:
    if node.op == 'call_function':
        if node.target == torch.relu:
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(torch.sigmoid, args=(copy.deepcopy(node),))
                node.replace_all_uses_with(new_node)

gm.recompile()
print(gm.code)

This gives the correct code:

def forward(self, x):
    relu = torch.relu(x);  x = None
    sigmoid = torch.sigmoid(relu);  relu = None
    return sigmoid

So what is the recommended way to append a node?

I hit the same issue. I think the reason is this:
after running

new_node = gm.graph.call_function(torch.sigmoid, args=(node,))

new_node has already been added to the users of node, so the next statement

node.replace_all_uses_with(new_node)

changes the input of every user of node, including new_node, so the input of new_node becomes new_node itself.
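
To check this explanation, here is a minimal sketch that inspects node.users before and after the two calls. The variable names (relu_node, users_before, new_node_is_user, self_referencing) are only for illustration, and it assumes the same sample module M as in the question.

```python
import torch
from torch import fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


gm = fx.symbolic_trace(M())
relu_node = next(
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target == torch.relu
)

# Before inserting anything, the only user of relu is the output node.
users_before = list(relu_node.users)

with gm.graph.inserting_after(relu_node):
    new_node = gm.graph.call_function(torch.sigmoid, args=(relu_node,))

# Passing relu_node as an argument registers new_node as one of its users.
new_node_is_user = new_node in relu_node.users

# replace_all_uses_with rewrites every user of relu_node, new_node included,
# so new_node ends up taking itself as input.
relu_node.replace_all_uses_with(new_node)
self_referencing = new_node.args[0] is new_node

print(users_before, new_node_is_user, self_referencing)
```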

I solved it by resetting the input of new_node, changing it from new_node back to node:

node.replace_all_uses_with(new_node)
new_node.replace_input_with(new_node, node)
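
For completeness, here is the workaround above put into a full runnable sketch. The helper name append_sigmoid_after_relu is made up for this example, and the graph.lint() call and the numerical comparison are extra sanity checks I added, not something the tutorial requires. This is one way that seems to work, not necessarily the officially recommended pattern.

```python
import torch
from torch import fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


def append_sigmoid_after_relu(gm: fx.GraphModule) -> fx.GraphModule:
    """Hypothetical helper: insert torch.sigmoid after every torch.relu call."""
    # Iterate over a snapshot so that inserting nodes does not affect the loop.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.relu:
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(torch.sigmoid, args=(node,))
            # This also rewrites new_node's own input to new_node ...
            node.replace_all_uses_with(new_node)
            # ... so point that input back at the original relu node.
            new_node.replace_input_with(new_node, node)
    gm.graph.lint()   # checks that the graph is well-formed
    gm.recompile()
    return gm


gm = append_sigmoid_after_relu(fx.symbolic_trace(M()))
print(gm.code)

# Compare against the eager computation on a random input.
x = torch.randn(4)
expected = torch.sigmoid(torch.relu(x))
actual = gm(x)
```

Setting new_node.args = (node,) after replace_all_uses_with should have the same effect as the replace_input_with call, since both just point the argument back at the relu node.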