Hi, there
I am new to torch.fx
but I like it! I met a problem when I attempted to insert an operator into an existing GraphModule.
If I insert a torch.relu
at the end of the graph, the following program runs successfully. But I fail to insert a torch.nn.Linear
into the graph.
from torch.fx import symbolic_trace
import torch.fx as fx
import torch.nn as nn
import torch.fx
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(100, 50)
def forward(self, x):
x = self.fc1(x)
return x
toy_net = Net()
toy_net_gm = torch.fx.symbolic_trace(toy_net)
print(toy_net_gm.graph)
print(toy_net_gm.code)
def manipulate_graph(graph_module, new_function, module_type):
for last_node in graph_module.graph.nodes:
if last_node.op == "output":
new_func_node = graph_module.graph.call_function(new_function, args=(last_node.prev,))
with graph_module.graph.inserting_before(last_node):
last_node.replace_all_uses_with(new_func_node)
break
else:
raise Exception("No found output node")
graph_module.graph.erase_node(last_node)
graph_module.graph.output(new_func_node, type_expr=getattr(new_func_node, 'type', None))
graph_module.graph.lint()
return fx.GraphModule(module_type, graph_module.graph)
new_gm = manipulate_graph(toy_net_gm, torch.relu, toy_net)
new_gm.recompile()
print(new_gm.graph)
print(new_gm.code)
I tried to insert a torch.nn.Linear(50,10)
into the graph but got an exception:
File "E:/pyworkspace/pytorch/add_linear.py", line 58, in <module>
new_gm = manipulate_graph(toy_net_gm, torch.nn.Linear(50,10), toy_net)
File "E:/pyworkspace/pytorch/add_linear.py", line 47, in manipulate_graph
new_func_node = graph_module.graph.call_function(new_function, args=(last_node.prev,))
File "D:\Softwares\anaconda3\envs\library\lib\site-packages\torch\fx\graph.py", line 728, in call_function
return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
File "D:\Softwares\anaconda3\envs\library\lib\site-packages\torch\fx\graph.py", line 428, in create_node
candidate = name if name is not None else self._target_to_str(target)
File "D:\Softwares\anaconda3\envs\library\lib\site-packages\torch\fx\graph.py", line 784, in _target_to_str
op = target.__name__
File "D:\Softwares\anaconda3\envs\library\lib\site-packages\torch\nn\modules\module.py", line 1185, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Linear' object has no attribute '__name__'. Did you mean: '__ne__'?
The question is how to correctly insert a torch.nn.XXX
operator in the graph ?(instead of the function like torch.relu
)