Proxy and module calls

Hi,

Following this discussion about inserting a block instead of a node, I would like to rebuild a model using Proxy.
However, I am unable to build a model with calls to modules. Here is an example:

import torch.nn as nn
import torch.fx as fx

module = nn.Module()
submodule = nn.Linear(10, 1)
module.add_module('sub', submodule)

graph = fx.Graph()
raw = graph.placeholder('x')
x = fx.Proxy(raw)
y = submodule(x)
output = graph.output(y.node)
gm = fx.GraphModule(module, graph)

result:

Traceback (most recent call last):

  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)

  File "<ipython-input-37-1d2b0a63ef08>", line 1, in <module>
    gm = fx.GraphModule(module, graph)

  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 194, in __init__
    self.graph = graph

  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/nn/modules/module.py", line 995, in __setattr__
    object.__setattr__(self, name, value)

  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 217, in graph
    self.recompile()

  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 297, in recompile
    cls.forward = _forward_from_src(self._code)

  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 41, in _forward_from_src
    exec_with_source(src, gbls)

  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 26, in exec_with_source
    exec(compile(src, key, 'exec'), globals)

  File "<eval_with_key_10>", line 3
    linear_1 = torch.nn.functional.linear(x, Parameter containing:
                                                                ^
SyntaxError: invalid syntax

When I print the graph I obtain this:

graph(x):
    %linear_1 : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%x, Parameter containing:
tensor([[-0.1319, -0.2979,  0.2596, -0.1899,  0.1811, -0.1425,  0.1439, -0.0744,
         -0.0514,  0.1271]], requires_grad=True)), kwargs = {bias: None})
    return linear_1

Any help would be appreciated.
Thank you.

This seems to happen with modules containing parameters.
Is it a bug?
If not, how should I handle them?

The example in the GitHub repository doesn’t showcase the use of Proxy with these types of modules.

Thank you in advance.

Hi @nscotto,

To resolve modules/parameters/buffers while tracing, a Proxy must be initialized with an underlying Tracer instance. It’s kind of hacky, but the quickest way I found to fix your code was to write it like this:

import torch
import torch.nn as nn
import torch.fx as fx

module = nn.Module()
submodule = nn.Linear(10, 1)
module.add_module('sub', submodule)

graph = fx.Graph()
raw = graph.placeholder('x')
tracer = torch.fx.Tracer()
tracer.root = module
tracer.graph = graph
x = fx.Proxy(raw, tracer)
y = submodule(x)
output = graph.output(y.node)
gm = fx.GraphModule(module, graph)

print(gm.code)

I think a more idiomatic way of doing this would be to use a custom Tracer that specifies which modules not to treat as leaves:

import torch.fx

class SomeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

traced = torch.fx.symbolic_trace(SomeModule())
print(traced.code)
"""
def forward(self, x):
    linear = self.linear(x);  x = None
    return linear
"""

# Now we can use a custom tracer to set the linear instance as a non-leaf module
# and thus trace into it

class InlineLinearTracer(torch.fx.Tracer):
    def __init__(self, linears_to_inline):
        super().__init__()
        self.linears_to_inline = linears_to_inline

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        if m in self.linears_to_inline:
            return False
        return super().is_leaf_module(m, module_qualified_name)


tracer = InlineLinearTracer([traced.linear])
graph = tracer.trace(traced)
inlined_gm = torch.fx.GraphModule(traced, graph)
print(inlined_gm.code)
"""
def forward(self, x):
    linear_weight = self.linear.weight
    linear_bias = self.linear.bias
    linear = torch.nn.functional.linear(x, linear_weight, bias = linear_bias);  x = linear_weight = linear_bias = None
    return linear
"""

Thank you for the answer.
Let me add a bit more context to my problem.

I would like to replace a block with another block (e.g. a backbone). This block may have leaking connections (skip connections from inside the block to outside of it).

Previously I was doing this by adding the new block as a single Module with add_module, then adding a node to the graph (using graph.inserting_before(node)) and erasing every node of the replaced block.
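
Roughly, that previous approach looked like this (a simplified sketch with made-up module names, not my actual code):

import torch.nn as nn
import torch.fx as fx

# Toy model standing in for my real one; 'block' plays the role of the backbone
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Linear(10, 10)
        self.head = nn.Linear(10, 1)

    def forward(self, x):
        return self.head(self.block(x))

gm = fx.symbolic_trace(Net())

# Register the replacement block as a single submodule
gm.add_module('new_block', nn.Linear(10, 10))

# Insert a call to the new submodule in front of the old call_module node,
# reroute its users, then erase the old node
for node in list(gm.graph.nodes):
    if node.op == 'call_module' and node.target == 'block':
        with gm.graph.inserting_before(node):
            new_node = gm.graph.call_module('new_block', args=node.args)
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.recompile()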

The problem with this approach is that it didn’t let me remap the leaking connections, because I was inserting the new block as a single module (as if it were a single layer).
So I would like to insert the new block “layer by layer”/node by node, and I am looking for a way to do it.

The solution proposed in this discussion is to use Proxy to recreate the graph, but I am struggling to use Proxy with modules, hence this post.
While the proposed solution solves the small example I gave, I am not sure how to use it in my context, that is, merging two Modules together/replacing one block with another.

So right now I am trying to rebuild a module as a submodule; I think it’s a step towards my goal of being able to replace a backbone and handle intermediate skip connections.

I have some code for creating a new module that mimics another module (but uses it as a submodule).

Here is the code:

import copy

import torch
import torch.nn as nn
import torch.fx as fx


def rebuild_module_as_submodule(module: fx.GraphModule):
    module = copy.deepcopy(module)

    new_module = nn.Module()
    new_module.add_module('mod', module)

    graph = fx.Graph()
    raw1 = graph.placeholder('x')

    tracer = torch.fx.Tracer()
    tracer.root = new_module
    tracer.graph = graph

    x = fx.Proxy(raw1)

    nodes = list(module.graph.nodes)
    node2proxy = {raw1: x, nodes[0]: x}

    def nodeargs2proxy(args):
        if isinstance(args, (list, tuple)):
            return tuple(nodeargs2proxy(arg) for arg in args)
        elif isinstance(args, dict):
            return {nodeargs2proxy(key): nodeargs2proxy(val) for key, val in args.items()}
        elif isinstance(args, fx.Node):
            return node2proxy[args]
        else:
            return args

    submodules = dict(new_module.named_modules())
    
    for node in nodes[1:-1]:
        if node.op == 'call_module':
            args = nodeargs2proxy(node.args)
            kwargs = nodeargs2proxy(node.kwargs)
            node2proxy[node] = submodules[f'mod.{node.target}'](*args, **kwargs)
        elif node.op == 'call_function':
            args = nodeargs2proxy(node.args)
            kwargs = nodeargs2proxy(node.kwargs)
            node2proxy[node] = node.target(*args, **kwargs)
        elif node.op == 'call_method':
            self_obj, *args = nodeargs2proxy(node.args)
            kwargs = nodeargs2proxy(node.kwargs)
            node2proxy[node] = getattr(self_obj, node.target)(*args, **kwargs)
        else:
            raise NotImplementedError(f'Node op: op={node.op}, name={node.name}, tgt={node.target}, args={node.args}, kwargs={node.kwargs}')

    last_proxy: fx.Proxy = node2proxy[node]
    graph.output(last_proxy.node)

    return new_module, graph

def test_2():
    submodule = nn.Linear(10, 1)
    module = fx.symbolic_trace(nn.Sequential(submodule))
    new_module, graph = rebuild_module_as_submodule(module)
    print(new_module)
    print(graph)
    gm = fx.GraphModule(new_module, graph)
    print(gm)

test_2()

I am still getting the same error as before, even though I created an instance of a Tracer and assigned it the module and graph I am working with.

Here is the result of the above code:

{'': Module(
  (mod): GraphModule(
    (0): Linear(in_features=10, out_features=1, bias=True)
  )
), 'mod': GraphModule(
  (0): Linear(in_features=10, out_features=1, bias=True)
), 'mod.0': Linear(in_features=10, out_features=1, bias=True)}
Module(
  (mod): GraphModule(
    (0): Linear(in_features=10, out_features=1, bias=True)
  )
)
graph(x):
    %linear_1 : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%x, Parameter containing:
tensor([[-0.3050,  0.0150, -0.2705, -0.0750, -0.0539, -0.0999,  0.0665, -0.0039,
          0.3120, -0.2078]], requires_grad=True)), kwargs = {bias: Parameter containing:
tensor([0.1116], requires_grad=True)})
    return linear_1
Traceback (most recent call last):
  File "test.py", line 91, in <module>
    test_2()
  File "test.py", line 88, in test_2
    gm = fx.GraphModule(new_module, graph)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 194, in __init__
    self.graph = graph
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/nn/modules/module.py", line 995, in __setattr__
    object.__setattr__(self, name, value)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 217, in graph
    self.recompile()
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 297, in recompile
    cls.forward = _forward_from_src(self._code)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 41, in _forward_from_src
    exec_with_source(src, gbls)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/graph_module.py", line 26, in exec_with_source
    exec(compile(src, key, 'exec'), globals)
  File "<eval_with_key_2>", line 3
    linear_1 = torch.nn.functional.linear(x, Parameter containing:
                                                                ^
SyntaxError: invalid syntax

It is still the same problem of parameters not being resolved through the module, and I don’t know how to handle it in my case.
Do you have any idea how to fix this?