FX graph reordering a torch model's positional args based on execution order

It seems like the order of the args in an FX graph is determined by the order in which the args are first used during execution, rather than by the order in the model definition.

Is this expected behaviour? If it is, how should I obtain my FX graph without running into it, or otherwise work around it?

Here is a minimal example:

import torch

def model(x_1, x_2, x_3):
    x_5 = torch.nn.functional.conv2d(x_1, x_2, x_3)
    return x_5

def model_1(x_1, x_2, x_3):
    x_4 = x_2
    x_5 = torch.nn.functional.conv2d(x_1, x_4, x_3)
    return x_5

x_1 = torch.randn([4, 6, 8, 4])
x_2 = torch.randn([4, 6, 2, 2])
x_3 = torch.randn([4])

def get_fx(model, args):
    # Collect every FX graph Dynamo hands to the backend.
    graphs = []
    def some_backend(graph_module, sample_inputs):
        graph_module.print_readable()
        graphs.append(graph_module)
        return graph_module  # run the captured graph unchanged

    torch._dynamo.reset()  # clear cached compilations so each model is traced fresh
    torch.compile(model, backend=some_backend)(*args)
    return graphs

graph_module = get_fx(model, [x_1, x_2, x_3])
output = graph_module[0](x_1, x_2, x_3)

graph_module_1 = get_fx(model_1, [x_1, x_2, x_3])
output_1 = graph_module_1[0](x_1, x_2, x_3)

In the above code, I have 2 different models. The only difference is that model_1 accesses x_2 (by assigning it to x_4) before it uses x_1.

get_fx() compiles each model with a custom backend that prints the FX graphs Dynamo creates and collects them in a list.

The printed graph for graph_module is

def forward(self, L_x_1_ : torch.Tensor, L_x_2_ : torch.Tensor, L_x_3_ : torch.Tensor):
    l_x_1_ = L_x_1_
    l_x_2_ = L_x_2_
    l_x_3_ = L_x_3_
        
    x_5 = torch.conv2d(l_x_1_, l_x_2_, l_x_3_);  l_x_1_ = l_x_2_ = l_x_3_ = None
    return (x_5,)

which is expected.

The printed graph for graph_module_1 is

def forward(self, L_x_2_ : torch.Tensor, L_x_1_ : torch.Tensor, L_x_3_ : torch.Tensor):
    x_4 = L_x_2_
    l_x_1_ = L_x_1_
    l_x_3_ = L_x_3_
        
    x_5 = torch.conv2d(l_x_1_, x_4, l_x_3_);  l_x_1_ = x_4 = l_x_3_ = None
    return (x_5,)

Here the placeholders are ordered L_x_2_, L_x_1_, L_x_3_, so x_2 is expected before x_1.

As a result, graph_module[0](x_1, x_2, x_3) runs without issue, but graph_module_1[0](x_1, x_2, x_3) raises an error because the inputs arrive in the wrong order: conv2d receives x_2 as the input and x_1 as the weight.
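If you only need to run the model, one way to sidestep the problem is to call the function returned by torch.compile instead of the raw graph module. Dynamo wraps the captured graph with code that maps the original Python arguments onto the graph's placeholders, so the model's signature is preserved. A minimal sketch; passthrough_backend is a hypothetical name for a backend that returns the graph unchanged.

```python
import torch

def model_1(x_1, x_2, x_3):
    x_4 = x_2
    return torch.nn.functional.conv2d(x_1, x_4, x_3)

def passthrough_backend(gm, example_inputs):
    # Return the captured graph unchanged; Dynamo's wrapper handles
    # mapping the caller's arguments onto the graph's placeholders.
    return gm

compiled = torch.compile(model_1, backend=passthrough_backend)

x_1 = torch.randn([4, 6, 8, 4])
x_2 = torch.randn([4, 6, 2, 2])
x_3 = torch.randn([4])

# Called in the original signature order, with the original return type.
out = compiled(x_1, x_2, x_3)
```

This does not give you a standalone graph with signature-ordered inputs, only a callable that behaves like the original function.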

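This is intended behaviour: Dynamo creates graph inputs in the order it first accesses them while tracing, not in the order of the function signature. If you need a graph whose inputs follow the original signature, use torch.export, which maintains the correct input mapping.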