Shape inference in FX Graphs

Hi, I’m studying the TorchDynamo graph export workflow and noticed that the exported FX graphs have symbolic function arguments. For instance:

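import torch
import torch.nn as nn
import torch._dynamo

class MatMulNet(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def forward(self, x, y):
        return torch.matmul(x, y)

model = MatMulNet()

# export returns the traced GraphModule together with its guards
gm, guards = torch._dynamo.export(model, aten_graph=True)(
    torch.rand(5, 5), torch.rand(5, 5))
print(gm)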

produces the following ATen graph:

GraphModule()
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    mm_default = torch.ops.aten.mm.default(arg0, arg1);  arg0 = arg1 = None
    return pytree.tree_unflatten([mm_default], self._out_spec)

As you can see, even though I pass concrete tensors to the export function, the forward of the traced FX graph takes symbolic arguments x and y with no shape information attached.
I’m curious where shape inference happens and where the placeholder nodes get their concrete shapes.
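For comparison, plain torch.fx ships a simple shape-inference pass, torch.fx.passes.shape_prop.ShapeProp, which runs a graph node by node on example inputs and records a TensorMetadata (shape, dtype, ...) in each node's meta["tensor_meta"]. The sketch below applies it to a symbolically traced version of the same model. This is only an illustration of the idea; I'm not claiming dynamo export uses ShapeProp internally.

```python
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp

class MatMulNet(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)

# symbolic_trace gives a torch-level FX graph whose placeholders carry no shapes yet
gm = torch.fx.symbolic_trace(MatMulNet())

# ShapeProp executes the graph on the example inputs and stores
# shape/dtype metadata on every node that produces a tensor
ShapeProp(gm).propagate(torch.rand(5, 5), torch.rand(5, 5))

shapes = {}
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    if meta is not None:
        shapes[node.name] = tuple(meta.shape)
        print(node.op, node.name, meta.dtype, tuple(meta.shape))
```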

When I generate debug traces with TORCH_COMPILE_DEBUG=1, I do see concrete shape and dtype annotations in fx_graph_readable.py:

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[5, 5], arg1_1: f32[5, 5]):
        # File: /home/pytorch-cuda/my_scripts/mytorch.py:35, code: return torch.matmul(x, y)
        mm: f32[5, 5] = torch.ops.aten.mm.default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
        return (mm,)
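Annotations like f32[5, 5] can be computed without touching real data. As a minimal sketch of that idea, assuming FakeTensorMode from the private torch._subclasses.fake_tensor module (its API may change between releases), fake tensors keep only shape, dtype and device, yet operators applied to them still return correctly shaped outputs. I use a 5x3 second operand here only so the inferred shape differs from the inputs.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()

# Convert real tensors into FakeTensors: same metadata, no real storage contents
x = fake_mode.from_tensor(torch.rand(5, 5))
y = fake_mode.from_tensor(torch.rand(5, 3))

# Under the mode, ops run only their shape/dtype logic and return FakeTensors
with fake_mode:
    out = torch.matmul(x, y)

print(type(out).__name__, out.dtype, tuple(out.shape))
```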

Update: I found that this information is carried in the meta dictionary of each torch.fx.Node, in the form of a FakeTensor.
