Hi,
I am getting this warning when inserting a call_module
node even though I called GraphModule.add_submodule
(or .add_module
, or even both):
/(...)/anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/graph.py:606: UserWarning: Attempted to insert a call_module Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule
warnings.warn("Attempted to insert a call_module Node with "
Here is the minimal example that produces the warning:
import torch
import torch.fx as fx
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
class MyOtherModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(24, 32, 3)
def forward(self, x):
x = nn.functional.relu(self.conv(x))
def insert_after(mod: fx.GraphModule, from_idx: int, insertme: fx.GraphModule):
mod.add_submodule('inserted', insertme)
cutoff_node: fx.Node = list(mod.graph.nodes)[from_idx]
next_node: fx.Node = list(mod.graph.nodes)[from_idx+1]
with mod.graph.inserting_after(cutoff_node):
new_node = model.graph.call_module('inserted', (cutoff_node,), {})
next_node.replace_input_with(cutoff_node, new_node)
mod.delete_all_unused_submodules()
mod.graph.eliminate_dead_code()
mod.recompile()
mod.graph.lint()
return mod
mod = fx.symbolic_trace(MyModule())
other = fx.symbolic_trace(MyOtherModule())
new_mod = insert_after(mod, 0, other)
As you can see I am inserting a call_module
node on a module I just added with GraphModule.add_submodule
, but the warning seems not to see it.
Have I missed something?