torch.export + fx.GraphModule: Unable to lift constants (parameters/buffers) to placeholders for memory sharing across modules

Issue Description:

I am trying to lift parameters and buffers from get_attr nodes to placeholder nodes in a graph module obtained via torch.export (the result of ep.module()), so that multiple modules can share the same tensor memory at inference time.

However, after manually replacing the get_attr nodes with placeholders and recompiling, calling the module raises a ValueError: the input tree spec recorded at export time (one positional argument) no longer matches the spec of the inputs I pass (ten positional arguments).

Minimal Reproducible Example:

Python

import torch
from torch.export import export

# 1. Define a simple model
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 64)
        self.bn = torch.nn.BatchNorm1d(64)
        self.fc2 = torch.nn.Linear(64, 5)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNet()
x = torch.randn(3, 10)

# 2. Export the model
ep = export(model, (x,))
gm = ep.module()

# 3. Lift parameters and buffers to placeholders
state_dict = dict(gm.named_buffers())
state_dict.update(gm.named_parameters())

lifted_names = []
for node in list(gm.graph.nodes):
    if node.op == 'get_attr':
        name = node.target
        if name not in state_dict:
            continue
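        # Insert each new placeholder before the current first node, so the
        # placeholders end up in reverse order of the get_attr nodes visited.
        with gm.graph.inserting_before(next(iter(gm.graph.nodes))):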
            ph = gm.graph.placeholder(name.replace('.', '_') + '_ph')
        node.replace_all_uses_with(ph)
        lifted_names.append(name)
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()

# 4. Construct new inputs
examples = (
    state_dict['bn.num_batches_tracked'],
    state_dict['bn.running_var'],
    state_dict['bn.running_mean'],
    state_dict['fc2.bias'],
    state_dict['fc2.weight'],
    state_dict['bn.bias'],
    state_dict['bn.weight'],
    state_dict['fc1.bias'],
    state_dict['fc1.weight'],
    torch.randn(3, 10),  # x
)

# 5. Try to run the modified graph
out = gm(*examples)

Error Message:

ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*]),
  TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*, *, *, *, *, *, *, *, *, *]),
  TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
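
The mismatch in the message can be reproduced in isolation. The following is a minimal sketch, assuming the check compares the pytree structure of (args, kwargs) as the error text suggests; it uses the private torch.utils._pytree module and stand-in tensors, and it mirrors the comparison rather than calling the actual export internals.

```python
import torch
import torch.utils._pytree as pytree  # private module, may change between releases

x = torch.randn(3, 10)

# Structure recorded at export time: one positional argument, no kwargs.
_, exported_spec = pytree.tree_flatten(((x,), {}))

# Structure of the call after lifting: nine lifted tensors plus x, no kwargs.
# The zero tensors are hypothetical stand-ins for the parameters and buffers.
lifted = tuple(torch.zeros(1) for _ in range(9))
_, actual_spec = pytree.tree_flatten(((*lifted, x), {}))

# The two specs differ in the number of leaves, so an equality check fails.
print(exported_spec.num_leaves, actual_spec.num_leaves)
print(exported_spec == actual_spec)
```

Editing the FX graph changes the placeholders, but the spec stored on the module still describes the original single-argument call, which is consistent with the error above.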

Expected Behavior:

I expect to be able to lift constants (parameters/buffers) to placeholders so that:

  • I can pass them as inputs at runtime;

  • Multiple graph modules can share the same tensor memory (e.g. for inference optimization or memory-constrained environments)
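
For comparison, the behavior described in the two points above is what torch.func.functional_call gives in eager mode. This is not an export-based solution, only an illustration of the goal: the module names net_a and net_b and the shared dictionary are hypothetical.

```python
import torch
from torch.func import functional_call

# Two independent modules with the same architecture (hypothetical example).
net_a = torch.nn.Linear(10, 5)
net_b = torch.nn.Linear(10, 5)

# One set of tensors, passed in at call time instead of living on the modules.
shared = {"weight": torch.randn(5, 10), "bias": torch.zeros(5)}
x = torch.randn(3, 10)

with torch.no_grad():
    out_a = functional_call(net_a, shared, (x,))
    out_b = functional_call(net_b, shared, (x,))

# Both calls read the same storage, so the results agree.
print(torch.allclose(out_a, out_b))
```

What I am looking for is the equivalent of this for graph modules produced by torch.export.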

Request for Help / Suggestion:

Is there a recommended way to lift constants to placeholders in an exported graph module?