How to insert `torch.nn.Linear` in the existing graph

Hi, there

I am new to torch.fx but I like it! I met a problem when I attempted to insert an operator into an existing GraphModule.

If I insert a torch.relu at the end of the graph, the following program runs successfully. But I fail to insert a torch.nn.Linear into the graph.

from torch.fx import symbolic_trace
import torch.fx as fx
import torch.nn as nn

import torch.fx

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(100, 50)

    def forward(self, x):
        x = self.fc1(x)
        return x

toy_net = Net()
toy_net_gm = torch.fx.symbolic_trace(toy_net)
print(toy_net_gm.graph)
print(toy_net_gm.code)

def manipulate_graph(graph_module, new_function, module_type):
    for last_node in graph_module.graph.nodes:
        if last_node.op == "output":
            new_func_node = graph_module.graph.call_function(new_function, args=(last_node.prev,))
            with graph_module.graph.inserting_before(last_node):
                last_node.replace_all_uses_with(new_func_node)
            break
    else:
        raise Exception("No found output node")
    graph_module.graph.erase_node(last_node)
    graph_module.graph.output(new_func_node, type_expr=getattr(new_func_node, 'type', None))
    graph_module.graph.lint()
    return fx.GraphModule(module_type, graph_module.graph)

new_gm = manipulate_graph(toy_net_gm, torch.relu, toy_net)
new_gm.recompile()
print(new_gm.graph)
print(new_gm.code)

I tried to insert a torch.nn.Linear(50,10) into the graph but got an exception:

  File "E:/pyworkspace/pytorch/add_linear.py", line 58, in <module>
    new_gm = manipulate_graph(toy_net_gm, torch.nn.Linear(50,10), toy_net)
  File "E:/pyworkspace/pytorch/add_linear.py", line 47, in manipulate_graph
    new_func_node = graph_module.graph.call_function(new_function, args=(last_node.prev,))
  File "D:\Softwares\anaconda3\envs\library\lib\site-packages\torch\fx\graph.py", line 728, in call_function
    return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
  File "D:\Softwares\anaconda3\envs\library\lib\site-packages\torch\fx\graph.py", line 428, in create_node
    candidate = name if name is not None else self._target_to_str(target)
  File "D:\Softwares\anaconda3\envs\library\lib\site-packages\torch\fx\graph.py", line 784, in _target_to_str
    op = target.__name__
  File "D:\Softwares\anaconda3\envs\library\lib\site-packages\torch\nn\modules\module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Linear' object has no attribute '__name__'. Did you mean: '__ne__'?

The question is how to correctly insert a torch.nn.XXX operator in the graph ?(instead of the function like torch.relu)

Solved!

  1. Firstly, use toy_net_gm.add_module("fc1",torch.nn.Linear(50,10)) to define a module in toy_net_gm
  2. Then use graph_module.graph.call_module("fc1", args=(last_node.prev,)) to insert a Linear operator