Hi,
Following this discussion about inserting a block instead of a node, I would like to rebuild a model using `Proxy`.
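For context, here is a minimal sketch of building a graph with `Proxy` when only torch functions are involved. It assumes a PyTorch version where `fx.Proxy(node)` can be created from a bare `Node`. The `torch.relu(x) + 1` computation is purely illustrative. Calls on a `Proxy` like these are recorded as `call_function` nodes, and the resulting `GraphModule` compiles without problems:

```python
import torch
import torch.fx as fx

graph = fx.Graph()
# Wrap the placeholder node in a Proxy so ordinary Python/torch calls are recorded
x = fx.Proxy(graph.placeholder('x'))

# torch.relu and "+" on a Proxy each append a call_function node to the graph
y = torch.relu(x) + 1
graph.output(y.node)

# An empty nn.Module is enough as root: the graph references no submodules or parameters
gm = fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)
print(gm(torch.tensor([-1.0, 2.0])))
```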
However, I am unable to build a model that calls submodules. Here is an example:
```python
import torch.nn as nn
import torch.fx as fx

module = nn.Module()
submodule = nn.Linear(10, 1)
module.add_module('sub', submodule)

graph = fx.Graph()
raw = graph.placeholder('x')
x = fx.Proxy(raw)
y = submodule(x)
output = graph.output(y.node)
gm = fx.GraphModule(module, graph)
```
Running this produces the following error:
```
Traceback (most recent call last):
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-37-1d2b0a63ef08>", line 1, in <module>
    gm = fx.GraphModule(module, graph)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 194, in __init__
    self.graph = graph
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/nn/modules/module.py", line 995, in __setattr__
    object.__setattr__(self, name, value)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 217, in graph
    self.recompile()
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 297, in recompile
    cls.forward = _forward_from_src(self._code)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 41, in _forward_from_src
    exec_with_source(src, gbls)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 26, in exec_with_source
    exec(compile(src, key, 'exec'), globals)
  File "<eval_with_key_10>", line 3
    linear_1 = torch.nn.functional.linear(x, Parameter containing:
                                                       ^
SyntaxError: invalid syntax
```
When I print the graph, I get this:
```
graph(x):
    %linear_1 : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%x, Parameter containing:
tensor([[-0.1319, -0.2979, 0.2596, -0.1899, 0.1811, -0.1425, 0.1439, -0.0744,
         -0.0514, 0.1271]], requires_grad=True)), kwargs = {bias: None})
    return linear_1
```
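The printed graph points to the cause. Calling `submodule(x)` on a `Proxy` runs `nn.Linear.forward`, so FX records a `torch.nn.functional.linear` call and embeds the weight `Parameter` as a literal argument. The generated Python source then contains the `Parameter`'s repr, which is not valid syntax. Below is a minimal sketch of one possible workaround. It assumes the goal is to keep `sub` as a single module call: insert a `call_module` node with `Graph.call_module`, which refers to the submodule by its name in the root module, and wrap that node in a `Proxy` so later operations can keep using Proxy syntax:

```python
import torch
import torch.nn as nn
import torch.fx as fx

module = nn.Module()
module.add_module('sub', nn.Linear(10, 1))

graph = fx.Graph()
x = fx.Proxy(graph.placeholder('x'))

# Record a call_module node that targets the submodule by its qualified name,
# instead of calling the nn.Linear on the Proxy (which traces into its forward
# and inlines the weight Parameter as a constant argument).
y = fx.Proxy(graph.call_module('sub', args=(x.node,)))
graph.output(y.node)

# GraphModule resolves the 'sub' target from the root module passed here
gm = fx.GraphModule(module, graph)
print(gm.code)

inp = torch.randn(4, 10)
out = gm(inp)
```

With this approach the weights stay attributes of `gm.sub` rather than being baked into the generated code, so the `GraphModule` compiles and its parameters remain trainable.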
Any help would be appreciated.
Thank you.