[SOLVED] Calling symbolic_trace after inserting a call_module node causes errors

I get an error when calling symbolic_trace on a GraphModule after modifying it (by inserting a call_module node).
This is the exact error:

/home/SERILOCAL/n.perto/.anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/graph.py:606: UserWarning: Attempted to insert a call_module Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule
  warnings.warn("Attempted to insert a call_module Node with "
Traceback (most recent call last):
  File "pytorch_issue.py", line 44, in <module>
    new_mod = fx.symbolic_trace(mod)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 859, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/symbolic_trace.py", line 571, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "<eval_with_key_2>", line 3, in forward
  File "/home/SERILOCAL/n.perto/.anaconda3/envs/test/lib/python3.7/site-packages/torch/fx/graph_module.py", line 513, in wrapped_call
    raise e.with_traceback(None)
AttributeError: 'MyOtherModule' object has no attribute 'conv'

And here is the code:

import torch
import torch.fx as fx
import torch.nn as nn


class MyModule(nn.Module):
    """Identity module used as the host graph that receives the insertion."""

    def __init__(self):
        super(MyModule, self).__init__()

    def forward(self, x):
        """Hand ``x`` back untouched."""
        result = x
        return result


class MyOtherModule(nn.Module):
    """Small block to insert: a 3x3 convolution (24 -> 32 channels) then ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(24, 32, 3)

    def forward(self, x):
        """Apply ``self.conv`` followed by ReLU and return the activation."""
        x = nn.functional.relu(self.conv(x))
        # Without this return, forward() would yield None and the conv/relu
        # result would be discarded.
        return x


def insert_mod_after(mod: fx.GraphModule, from_idx: int, insertme: fx.GraphModule) -> fx.GraphModule:
    """Splice ``insertme`` into ``mod`` as a call_module node after node ``from_idx``.

    Registers ``insertme`` on ``mod`` as the submodule ``inserted``, adds a
    ``call_module('inserted')`` node that takes node ``from_idx`` as input, and
    makes node ``from_idx + 1`` read from the new node instead. ``mod`` is
    modified in place and also returned.

    This version reproduces the problem described in the question: a
    subsequent ``fx.symbolic_trace(mod)`` raises
    ``AttributeError: 'MyOtherModule' object has no attribute 'conv'``.
    """
    mod.add_submodule('inserted', insertme)
    cutoff_node: fx.Node = list(mod.graph.nodes)[from_idx]
    # NOTE(review): this is simply the next node in graph order; it is not
    # guaranteed to be a user of cutoff_node.
    next_node: fx.Node = list(mod.graph.nodes)[from_idx+1]

    with mod.graph.inserting_after(cutoff_node):
        new_node = mod.graph.call_module('inserted', (cutoff_node,), {})

    # Only next_node is rewired; any other users of cutoff_node keep reading
    # it directly and bypass the inserted module.
    next_node.replace_input_with(cutoff_node, new_node)

    # Root cause per the follow-up reply: at this point the graph references
    # only the target 'inserted', so its own children (e.g. 'inserted.conv')
    # are treated as unused and deleted (observed with the torch version in the
    # traceback; newer releases may behave differently). Re-tracing before
    # pruning avoids this.
    mod.delete_all_unused_submodules()
    mod.graph.eliminate_dead_code()
    mod.recompile()
    # NOTE(review): lint() runs after recompile(); linting first would catch an
    # invalid graph before code is generated from it.
    mod.graph.lint()

    return mod


# Build the host and the module to insert, splice them, then re-trace.
mod = fx.symbolic_trace(MyModule())
other = fx.symbolic_trace(MyOtherModule())
insert_mod_after(mod, 0, other)  # mutates ``mod`` in place (and returns it)
# The traceback shows this re-trace descending into the inserted module's
# generated forward(), which reads ``self.conv``. That submodule was deleted
# inside insert_mod_after, so this line raises AttributeError.
new_mod = fx.symbolic_trace(mod)

It may be related to the warning, although I don't understand why the warning appears, since I already call add_submodule. I have created a separate topic for the warning.

Thank you for reading.

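It seems that the problem comes from the call to mod.delete_all_unused_submodules().
I think it deletes the submodules of the inserted submodule (e.g. inserted.conv), because they don't appear in the graph even though the call_module node that calls the inserted submodule uses them internally.
Calling it after the second symbolic_trace seems to solve the issue: re-tracing inlines the inserted module, so its own submodules appear as call_module targets in the new graph and are no longer considered unused.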

import torch
import torch.fx as fx
import torch.nn as nn


class MyModule(nn.Module):
    """Identity module used as the host graph that receives the insertion."""

    def __init__(self):
        super(MyModule, self).__init__()

    def forward(self, x):
        """Hand ``x`` back untouched."""
        result = x
        return result


class MyOtherModule(nn.Module):
    """Small block to insert: a 3x3 convolution (24 -> 32 channels) then ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(24, 32, 3)

    def forward(self, x):
        """Apply ``self.conv`` followed by ReLU and return the activation."""
        x = nn.functional.relu(self.conv(x))
        # Without this return, forward() would yield None and the conv/relu
        # result would be discarded.
        return x


def insert_mod_after(mod: fx.GraphModule, from_idx: int, insertme: fx.GraphModule,
                     name: str = 'inserted') -> fx.GraphModule:
    """Insert ``insertme`` into ``mod``'s graph right after node ``from_idx``.

    Every consumer of node ``from_idx`` is rewired to read the output of the
    new ``call_module`` node instead, so the inserted module sits between that
    node and all of its users.

    Args:
        mod: Module to modify. Its graph and submodules are mutated in place.
        from_idx: Position, in ``mod.graph.nodes`` order, of the node whose
            output feeds ``insertme``. Must not select the output node.
        insertme: Module to call on that node's output.
        name: Attribute name under which ``insertme`` is registered on ``mod``.
            Use distinct names when inserting several modules.

    Returns:
        A freshly re-traced GraphModule with unused submodules removed.

    Raises:
        IndexError: If ``from_idx`` is out of range or selects the output node.
    """
    nodes = list(mod.graph.nodes)
    cutoff_node: fx.Node = nodes[from_idx]
    if cutoff_node.op == 'output':
        raise IndexError(
            f"cannot insert a module after the output node (from_idx={from_idx})")

    # Validate before mutating ``mod`` so a bad index leaves it untouched.
    mod.add_submodule(name, insertme)

    with mod.graph.inserting_after(cutoff_node):
        new_node = mod.graph.call_module(name, (cutoff_node,), {})

    # Rewire every user of cutoff_node, not just the next node in graph order
    # (which need not consume it). new_node itself must keep reading
    # cutoff_node, so skip it. Copy users first: rewiring mutates the dict.
    for user in list(cutoff_node.users):
        if user is not new_node:
            user.replace_input_with(cutoff_node, new_node)

    mod.graph.eliminate_dead_code()
    # Validate the graph before generating Python code from it.
    mod.graph.lint()
    mod.recompile()

    # Re-trace BEFORE pruning. When ``insertme`` is a GraphModule, tracing
    # descends into it, so its own submodules (e.g. ``inserted.conv``) become
    # explicit call_module targets. Pruning first would delete them, because
    # only ``name`` itself appears in the pre-trace graph, and the re-trace
    # would then fail with AttributeError.
    mod = fx.symbolic_trace(mod)
    mod.delete_all_unused_submodules()
    return mod


# Trace the module to insert and the host module, then splice them together;
# insert_mod_after hands back a freshly re-traced module.
other = fx.symbolic_trace(MyOtherModule())
mod = fx.symbolic_trace(MyModule())
new_mod = insert_mod_after(mod, from_idx=0, insertme=other)

I would like this kind of ordering to be documented, as it is not clear in which order Graph.eliminate_dead_code(), Graph.lint(), GraphModule.delete_all_unused_submodules(), and GraphModule.recompile() should be called.