Issue Description:
I am trying to lift parameters and buffers from get_attr nodes to placeholders in a graph module obtained via torch.export, so that multiple modules can share the same tensor memory at inference time.
However, after manually rewriting the graph to replace the get_attr nodes with placeholders, calling the module raises a ValueError about a mismatch between the exported input tree spec and the tree spec of the new inputs.
Minimal Reproducible Example:
Python
import torch
from torch.export import export

# 1. Define a simple model
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 64)
        self.bn = torch.nn.BatchNorm1d(64)
        self.fc2 = torch.nn.Linear(64, 5)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNet()
x = torch.randn(3, 10)

# 2. Export the model and unlift it back into a GraphModule
ep = export(model, (x,))
gm = ep.module()

# 3. Lift parameters and buffers to placeholders
state_dict = dict(gm.named_buffers())
state_dict.update(gm.named_parameters())
lifted_names = []
for node in list(gm.graph.nodes):
    if node.op == 'get_attr':
        name = node.target
        if name not in state_dict:
            continue
        # Each placeholder is inserted before the current first node, so the
        # final placeholder order is the reverse of the get_attr traversal.
        with gm.graph.inserting_before(next(iter(gm.graph.nodes))):
            ph = gm.graph.placeholder(name.replace('.', '_') + '_ph')
        node.replace_all_uses_with(ph)
        lifted_names.append(name)
        gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()

# 4. Construct new inputs, matching the (reversed) placeholder order above
examples = (
    state_dict['bn.num_batches_tracked'],
    state_dict['bn.running_var'],
    state_dict['bn.running_mean'],
    state_dict['fc2.bias'],
    state_dict['fc2.weight'],
    state_dict['bn.bias'],
    state_dict['bn.weight'],
    state_dict['fc1.bias'],
    state_dict['fc1.weight'],
    torch.randn(3, 10),  # x
)

# 5. Try to run the modified graph -- this raises the ValueError below
out = gm(*examples)
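If it helps with triage: the rewritten graph itself seems to be sound. Running it node-by-node with torch.fx.Interpreter, which as far as I can tell bypasses the generated forward() and any forward pre-hooks, goes through on my build, so the failure looks like it is in the input handling rather than in the graph. This is only a sanity check, not how I intend to call the module:
Python
# Sanity check: execute the graph directly, skipping the generated forward().
# enable_io_processing=False avoids re-flattening the inputs against the
# stale exported tree spec.
interp = torch.fx.Interpreter(gm)
out = interp.run(*examples, enable_io_processing=False)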
Error Message:
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*]),
TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*, *, *, *, *, *, *, *, *, *]),
TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
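From some digging, the error does not seem to come from the FX graph at all but from the wrapper that ep.module() returns. My understanding (possibly version-specific, and it relies on private internals) is that the original input tree spec is stored outside the graph nodes and re-checked by a forward pre-hook on every call, so gm.recompile() never updates it:
Python
# Hypothetical inspection -- these are private internals and the exact
# attributes may differ across PyTorch versions.
print(type(gm.graph._codegen))  # pytree-aware codegen holding the exported in_spec
print(gm._forward_pre_hooks)    # ep.module() registers an input-constraint pre-hook here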
Expected Behavior:
I expect to be able to lift constants (parameters/buffers) to placeholders so that:
- I can pass them as inputs at runtime;
- Multiple graph modules can share the same tensor memory, e.g. for inference optimization or memory-constrained environments (a rough sketch of the intended usage follows).
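For concreteness, the usage I am after looks roughly like this, where gm_a and gm_b are hypothetical stand-ins for two independently exported and rewritten modules:
Python
# Hypothetical target usage: both modules read the *same* parameter tensors,
# so no storage is duplicated between them.
shared = tuple(state_dict[name] for name in reversed(lifted_names))
out_a = gm_a(*shared, torch.randn(3, 10))
out_b = gm_b(*shared, torch.randn(3, 10))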
Request for Help / Suggestion:
Is there a recommended way to lift constants (parameters/buffers) to placeholders in an exported graph module, so that the module can then be called with them as explicit inputs?
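The closest I have gotten is to strip the exported input handling by hand, which leans on private internals (torch.fx.graph.CodeGen, _forward_pre_hooks) and is surely not the intended API:
Python
import torch.fx

# Hand-rolled workaround sketch -- relies on private internals that may
# change between releases.
gm._forward_pre_hooks.clear()                 # drop the exported input-spec pre-hook
gm.graph._codegen = torch.fx.graph.CodeGen()  # plain positional forward(), no pytree flattening
gm.recompile()

out = gm(*examples)  # note: the output may now be the flat structure, e.g. a 1-tuple

Is there a supported way to achieve the same thing?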