I am getting an error when calling symbolic_trace on a GraphModule after modifying it (inserting a call_module node).
This is the exact error:
/home/SERILOCAL/n.perto/.anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/graph.py:606: UserWarning: Attempted to insert a call_module Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule
  warnings.warn("Attempted to insert a call_module Node with "
Traceback (most recent call last):
  File "pytorch_issue.py", line 44, in <module>
    new_mod = fx.symbolic_trace(mod)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 859, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 571, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "<eval_with_key_2>", line 3, in forward
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/graph_module.py", line 513, in wrapped_call
    raise e.with_traceback(None)
AttributeError: 'MyOtherModule' object has no attribute 'conv'
And here is the code:
import torch
import torch.fx as fx
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class MyOtherModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(24, 32, 3)

    def forward(self, x):
        x = nn.functional.relu(self.conv(x))
        return x


def insert_mod_after(mod: fx.GraphModule, from_idx: int, insertme: fx.GraphModule):
    # Register the module to insert under the name 'inserted'
    mod.add_submodule('inserted', insertme)
    cutoff_node: fx.Node = list(mod.graph.nodes)[from_idx]
    next_node: fx.Node = list(mod.graph.nodes)[from_idx+1]
    # Call 'inserted' right after cutoff_node and feed its result to next_node
    with mod.graph.inserting_after(cutoff_node):
        new_node = mod.graph.call_module('inserted', (cutoff_node,), {})
    next_node.replace_input_with(cutoff_node, new_node)
    # Clean up, regenerate forward() and validate the graph
    mod.delete_all_unused_submodules()
    mod.graph.eliminate_dead_code()
    mod.recompile()
    mod.graph.lint()
    return mod


mod = fx.symbolic_trace(MyModule())
other = fx.symbolic_trace(MyOtherModule())
insert_mod_after(mod, 0, other)
# Re-tracing the modified GraphModule is where the AttributeError is raised
new_mod = fx.symbolic_trace(mod)
The error looks like it could be related to the warning, although I don't understand why the warning is raised, since I am already calling add_submodule before inserting the node. I created another topic for the warning.
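One check that might help narrow this down is to list the submodules registered on the outer GraphModule before and after delete_all_unused_submodules, to see whether anything nested under 'inserted' (such as 'inserted.conv') disappears. This is only a diagnostic sketch and not a confirmed explanation: the helper registered_submodule_names and the classes Passthrough and ConvRelu are hypothetical names that mirror the code above, and the result may differ between PyTorch versions.

```python
import torch
import torch.fx as fx
import torch.nn as nn


def registered_submodule_names(gm: nn.Module) -> list:
    """Hypothetical helper: sorted names of all nested submodules (root excluded)."""
    return sorted(name for name, _ in gm.named_modules() if name)


class Passthrough(nn.Module):
    def forward(self, x):
        return x


class ConvRelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(24, 32, 3)

    def forward(self, x):
        return nn.functional.relu(self.conv(x))


outer = fx.symbolic_trace(Passthrough())
inner = fx.symbolic_trace(ConvRelu())

# Same kind of edit as insert_mod_after: register the submodule, then call it
# between the placeholder and the output node.
outer.add_submodule("inserted", inner)
placeholder = next(iter(outer.graph.nodes))
output_node = next(n for n in outer.graph.nodes if n.op == "output")
with outer.graph.inserting_after(placeholder):
    call = outer.graph.call_module("inserted", (placeholder,))
output_node.replace_input_with(placeholder, call)
outer.recompile()

# Compare what is registered before and after the cleanup step.
names_before = registered_submodule_names(outer)
outer.delete_all_unused_submodules()
names_after = registered_submodule_names(outer)
removed = sorted(set(names_before) - set(names_after))

print("before:", names_before)
print("after: ", names_after)
print("removed by delete_all_unused_submodules:", removed)
```

If the removed list is not empty, the cleanup step would be the first place to look; if it is empty, the missing attribute has to come from somewhere else.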
Thank you for reading