Dynamo Graph Capture automatic transform from `get_attr` to `placeholder`?

I’m exploring how graph capture differs between TorchDynamo and `torch.fx` symbolic tracing. Specifically, I’ve noticed that when a model is traced with Dynamo, `get_attr` nodes are often converted to `placeholder` nodes automatically.

Here’s a simple runnable demo that illustrates what I’m observing:

import torch
from torch.fx import GraphModule, Tracer, Graph
from torch.export import export

class CustomModel(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

    def forward(self, x):
        x = torch.mm(x, self.w1)
        # x = gelu(x)  # Uncomment for non-linear operations
        x = x.sum()
        return (x,)

if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("meta"):
        hidden_size = 1024
        model = CustomModel(hidden_size)
        inp = torch.zeros(2, hidden_size, requires_grad=True)
        
        tracer = Tracer()
        graph = tracer.trace(model)
        graph.print_tabular()
        
        exported_program: torch.export.ExportedProgram = export(model, args=(inp,))
        gm = exported_program.graph_module
        gm.graph.print_tabular()

The output from tracing with torch.fx directly differs from the output of export (which I presume uses Dynamo internally). In the FX trace, parameters are kept as get_attr nodes, whereas in the Dynamo-based trace they become placeholder nodes.

Output using FX:

opcode         name    target                                                 args         kwargs
-------------  ------  -----------------------------------------------------  -----------  --------
placeholder    x       x                                                      ()           {}
get_attr       w1      w1                                                     ()           {}
call_function  mm      <built-in method mm of type object at 0x7f661c873500>  (x, w1)      {}
call_method    sum_1   sum                                                    (mm,)        {}
output         output  output                                                 ((sum_1,),)  {}

Output using Dynamo:

opcode         name    target            args         kwargs
-------------  ------  ----------------  -----------  --------
placeholder    p_w1    p_w1              ()           {}
placeholder    x       x                 ()           {}
call_function  mm      aten.mm.default   (x, p_w1)    {}
call_function  sum_1   aten.sum.default  (mm,)        {}
output         output  output            ((sum_1,),)  {}

I am curious why this happens, and if there’s a way to control or prevent this behavior when using Dynamo. Any insights or recommendations on how to handle this discrepancy would be greatly appreciated.

Update:
It seems this is a Dynamo feature. The metadata (shape/dtype…) is stored on each node of the Dynamo graph in node.meta.

Update:

Dynamo-based export automatically functionalizes the graph: all parameters and buffers are lifted to graph inputs, and the whole graph is treated as one large forward function.

If you use export to get an ExportedProgram, you can then call torch.export.unflatten(exported) to transform the graph back into an UnflattenedModule. In that case the graph regains its module hierarchy, with nodes such as get_attr and call_module.

However, if you want to capture both the forward and backward (joint) graphs, you would use aot_export_module, which returns a torch.fx.GraphModule and a GraphSignature. This approach uses Dynamo and AOTAutograd to capture the graph at a lower level.

At this point, building an ExportedProgram and calling unflatten on it is not feasible, because the TreeSpec information in the GraphSignature is actually empty.

Nonetheless, the GraphSignature still contains mappings such as inputs_to_parameters, so we can manually recover which parameter each placeholder came from; we just cannot rebuild the submodule structure.