Hello,
In the `torch.fx` example for `replace_pattern`, a function call is replaced (`torch.add` by `torch.mul`).
That case is clear; however, it is not clear to me whether modules can be replaced as well, and if so, how to do it.
The following example failed for me with the error:
Traceback (most recent call last):
File "test.py", line 43, in <module>
torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/subgraph_rewriter.py", line 201, in replace_pattern
pattern_graph = symbolic_trace(pattern).graph
File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 606, in symbolic_trace
graph = tracer.trace(root, concrete_args)
File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 355, in trace
self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
File "test.py", line 33, in pattern
val1 = F.relu(lin1(x))
File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 344, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 221, in call_module
module_qualified_name = self.path_of_module(m)
File "/home/SERILOCAL/n.perto/.anaconda3/envs/automl/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 191, in path_of_module
raise NameError('module is not installed as a submodule')
NameError: module is not installed as a submodule
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx
class TwoLayerSubNet(nn.Module):
    """Small feed-forward block: ``linear1`` -> ReLU -> ``linear2``.

    Args:
        D_in: number of input features consumed by ``linear1``.
        H: hidden width (output of ``linear1``, input of ``linear2``).
        D_out: number of output features produced by ``linear2``.
    """

    def __init__(self, D_in=100, H=50, D_out=10):
        super().__init__()
        self.linear1 = nn.Linear(D_in, H)
        self.linear2 = nn.Linear(H, D_out)

    def forward(self, x):
        # No activation is applied after linear2; the block returns its raw output.
        return self.linear2(F.relu(self.linear1(x)))
class TwoLayerNet(nn.Module):
    """Two stacked ``TwoLayerSubNet`` blocks: D_in -> D_int -> D_out.

    Args:
        D_in: number of input features fed to ``sub1``.
        H: hidden width used inside both sub-networks.
        D_int: number of features ``sub1`` produces and ``sub2`` consumes.
        D_out: number of output features produced by ``sub2``.
    """

    def __init__(self, D_in=100, H=50, D_int=20, D_out=10):
        super().__init__()
        self.sub1 = TwoLayerSubNet(D_in, H, D_int)
        self.sub2 = TwoLayerSubNet(D_int, H, D_out)

    def forward(self, x):
        h_sub1 = self.sub1(x)
        # sub2 is built to take D_int features, i.e. sub1's output. Passing the
        # raw input x here (as before) skipped sub1 entirely and fails whenever
        # D_in != D_int (the defaults are 100 vs 20).
        return self.sub2(h_sub1)
def pattern(x):
lin1 = torch.nn.Linear(30, 20)
lin2 = torch.nn.Linear(20, 10)
val1 = F.relu(lin1(x))
return F.relu(lin2(val1))
def replacement(x):
lin1 = torch.nn.Linear(30, 10)
return F.relu(lin1(x))
# Build the model and trace it into a torch.fx GraphModule.
m = TwoLayerNet()
gm = torch.fx.symbolic_trace(m)
# Attempt to rewrite every occurrence of `pattern` in gm's graph with `replacement`.
# NOTE(review): this call raises NameError("module is not installed as a
# submodule") while tracing `pattern`. That function creates its nn.Linear
# layers locally instead of owning them as submodules. `replacement` has the
# same construction.
# NOTE(review): presumably replace_pattern matches call_module nodes by their
# qualified target name (e.g. "sub1.linear1"). If so, module-level replacement
# may need direct graph edits on gm rather than a pattern function. TODO
# confirm for the installed torch version.
torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
Am I missing something?
Thanks