torch.fx: replacing modules

Hello,

In the torch.fx example for replace_pattern, a function is replaced (torch.add by torch.mul).
This is very clear; however, it is not clear to me whether it is possible to replace modules as well, and if so, how to do it.
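For reference, the function-replacement case looks roughly like the sketch below. It is a minimal reconstruction, not the exact documentation example: the toy model M and its "* 2" step are made up for illustration. pattern and replacement are plain functions over the same placeholders, and replace_pattern rewrites every matched torch.add call into torch.mul inside gm.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    # Hypothetical toy model: one torch.add followed by a scalar multiply
    def forward(self, x, y):
        return torch.add(x, y) * 2


def pattern(x, y):
    return torch.add(x, y)


def replacement(x, y):
    return torch.mul(x, y)


gm = torch.fx.symbolic_trace(M())
# Rewrites gm's graph in place and regenerates its forward
torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)

called = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(called)
print(gm(torch.tensor(3.0), torch.tensor(4.0)))  # tensor(24.)
```

Nothing in pattern or replacement owns state here, which is why tracing them on their own works.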

The following example failed for me with the error:

Traceback (most recent call last):
  File "test.py", line 43, in <module>
    torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/subgraph_rewriter.py", line 201, in replace_pattern
    pattern_graph = symbolic_trace(pattern).graph
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 606, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 355, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "test.py", line 33, in pattern
    val1 = F.relu(lin1(x))
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 344, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 221, in call_module
    module_qualified_name = self.path_of_module(m)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 191, in path_of_module
    raise NameError('module is not installed as a submodule')
NameError: module is not installed as a submodule

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx


class TwoLayerSubNet(nn.Module):
    def __init__(self, D_in=100, H=50, D_out=10):
        super(TwoLayerSubNet, self).__init__()
        self.linear1 = nn.Linear(D_in, H)
        self.linear2 = nn.Linear(H, D_out)

    def forward(self, x):
        h_relu = F.relu(self.linear1(x))
        y_pred = self.linear2(h_relu)
        return y_pred


class TwoLayerNet(nn.Module):
    def __init__(self, D_in=100, H=50, D_int=20, D_out=10):
        super(TwoLayerNet, self).__init__()
        self.sub1 = TwoLayerSubNet(D_in, H, D_int)
        self.sub2 = TwoLayerSubNet(D_int, H, D_out)

    def forward(self, x):
        h_sub1 = self.sub1(x)
        y_pred = self.sub2(h_sub1)
        return y_pred

def pattern(x):
    lin1 = torch.nn.Linear(30, 20)
    lin2 = torch.nn.Linear(20, 10)
    val1 = F.relu(lin1(x))
    return F.relu(lin2(val1))


def replacement(x):
    lin1 = torch.nn.Linear(30, 10)
    return F.relu(lin1(x))

m = TwoLayerNet()
gm = torch.fx.symbolic_trace(m)
torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)

Am I missing something?

Thanks
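A stripped-down reproduction of the NameError, written as a sketch (local_linear and Registered are made-up names). When symbolic_trace is given a plain function, the tracer uses an empty nn.Module as its root, so an nn.Linear constructed inside that function is not a submodule of anything the tracer knows about. Calling it should therefore fail in path_of_module, as in the traceback above. The same layer registered in the __init__ of an nn.Module traces fine and appears as a call_module node.

```python
import torch
import torch.fx
import torch.nn as nn


def local_linear(x):
    # The layer is created during tracing and never registered on the tracer's root
    lin = nn.Linear(4, 4)
    return lin(x)


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)  # registered submodule, reachable under the name "lin"

    def forward(self, x):
        return self.lin(x)


try:
    torch.fx.symbolic_trace(local_linear)
    error = None
except NameError as exc:
    error = str(exc)

gm = torch.fx.symbolic_trace(Registered())
module_targets = [n.target for n in gm.graph.nodes if n.op == "call_module"]

print(error)
print(module_targets)  # ['lin']
```

So the failure happens while replace_pattern traces the pattern function itself (the symbolic_trace(pattern) frame in the traceback), before any matching against gm takes place.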

Did you solve it? I have the same problem with resnet18.

I am also curious, and I still have not figured out how to do it.

Right now I am looking at the test code in test_fx.py in the pytorch/pytorch GitHub repository (commit ebf7a4f843ac2c34f0f0765c57f8c0bb8a194686).

There is some functionality for adding and deleting individual (atomic) functions; however, I am not sure whether it is possible to replace larger units such as whole modules (classes).
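One possible workaround for swapping whole submodules, sketched under assumptions and not taken from test_fx.py, skips replace_pattern entirely. The idea is to trace with a custom Tracer whose is_leaf_module keeps every TwoLayerSubNet as a single call_module node, then edit the graph directly: register the new module on the GraphModule with add_module, insert a call_module node for it, and reroute the old node's users. SubNetLeafTracer, replace_submodules and the Linear+ReLU replacement are illustrative choices, not an official API. The models are the ones from the first post, with sub2 fed from sub1.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class TwoLayerSubNet(nn.Module):
    def __init__(self, D_in=100, H=50, D_out=10):
        super().__init__()
        self.linear1 = nn.Linear(D_in, H)
        self.linear2 = nn.Linear(H, D_out)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))


class TwoLayerNet(nn.Module):
    def __init__(self, D_in=100, H=50, D_int=20, D_out=10):
        super().__init__()
        self.sub1 = TwoLayerSubNet(D_in, H, D_int)
        self.sub2 = TwoLayerSubNet(D_int, H, D_out)

    def forward(self, x):
        return self.sub2(self.sub1(x))


class SubNetLeafTracer(torch.fx.Tracer):
    # Hypothetical tracer: record each TwoLayerSubNet as one call_module node instead of tracing into it
    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, TwoLayerSubNet):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def replace_submodules(gm, old_type, make_replacement):
    # Hypothetical helper: swap every call_module node whose module is an old_type instance
    modules = dict(gm.named_modules())
    for node in list(gm.graph.nodes):
        if node.op != "call_module" or not isinstance(modules[node.target], old_type):
            continue
        new_name = node.target.replace(".", "_") + "_replaced"
        gm.add_module(new_name, make_replacement(modules[node.target]))
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_module(new_name, node.args, node.kwargs)
        node.replace_all_uses_with(new_node)  # reroute the old node's users to the new node
        gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()  # regenerate gm.forward from the edited graph
    return gm


def linear_relu_like(sub):
    # Assumed replacement, mirroring the post's replacement(): a single Linear followed by ReLU
    return nn.Sequential(nn.Linear(sub.linear1.in_features, sub.linear2.out_features), nn.ReLU())


model = TwoLayerNet()
gm = torch.fx.GraphModule(model, SubNetLeafTracer().trace(model))
replace_submodules(gm, TwoLayerSubNet, linear_relu_like)

targets = [n.target for n in gm.graph.nodes if n.op == "call_module"]
out = gm(torch.randn(4, 100))
print(targets)    # ['sub1_replaced', 'sub2_replaced']
print(out.shape)  # torch.Size([4, 10])
```

Marking the blocks as leaves is what makes them visible as single nodes: with the default tracer, symbolic_trace inlines sub1 and sub2 down to their Linear layers and relu calls, so there would be no single node to swap. The old submodules stay attached to gm but are no longer called.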