FX graph reordering a torch model's positional args based on execution order

It seems like the order of the args in an FX graph depends on the order in which the args are used during execution, rather than on the order in which they are declared in the model definition.

Is this expected behaviour? And if it is, how should I acquire my FX graph without running into this behaviour? Or how could I circumvent this behaviour?

I’ve attached an example below

```python
import torch
import torch._dynamo


def model(x_1, x_2, x_3):
    x_5 = torch.nn.functional.conv2d(x_1, x_2, x_3)
    return x_5


def model_1(x_1, x_2, x_3):
    x_4 = x_2  # x_2 is accessed before x_1
    x_5 = torch.nn.functional.conv2d(x_1, x_4, x_3)
    return x_5


x_1 = torch.randn([4, 6, 8, 4])  # input:  (N, C_in, H, W)
x_2 = torch.randn([4, 6, 2, 2])  # weight: (C_out, C_in, kH, kW)
x_3 = torch.randn([4])           # bias:   (C_out,)


def get_fx(model, args):
    """Compile `model` with a custom backend that records and prints every FX graph Dynamo produces."""
    graphs = []

    def some_backend(graph_module, sample_inputs):
        graphs.append(graph_module)
        print(graph_module.code)
        return graph_module  # run the captured graph unchanged

    torch._dynamo.reset()
    torch.compile(model, backend=some_backend)(*args)
    return graphs


graph_module = get_fx(model, [x_1, x_2, x_3])
output = graph_module[0](x_1, x_2, x_3)  # works

graph_module_1 = get_fx(model_1, [x_1, x_2, x_3])
output_1 = graph_module_1[0](x_1, x_2, x_3)  # raises: this graph expects (x_2, x_1, x_3)
```

In the code above, I have two different models. The only difference is that in model_1, the argument x_2 is accessed (assigned to x_4) before x_1 is used.

get_fx() compiles a model with a custom torch.compile backend that captures, and prints, the FX graphs that get created from it.

The printed code for graph_module shows

```python
def forward(self, L_x_1_ : torch.Tensor, L_x_2_ : torch.Tensor, L_x_3_ : torch.Tensor):
    l_x_1_ = L_x_1_
    l_x_2_ = L_x_2_
    l_x_3_ = L_x_3_

    x_5 = torch.conv2d(l_x_1_, l_x_2_, l_x_3_);  l_x_1_ = l_x_2_ = l_x_3_ = None
    return (x_5,)
```

which is expected.

The printed code for graph_module_1 shows

```python
def forward(self, L_x_2_ : torch.Tensor, L_x_1_ : torch.Tensor, L_x_3_ : torch.Tensor):
    x_4 = L_x_2_
    l_x_1_ = L_x_1_
    l_x_3_ = L_x_3_

    x_5 = torch.conv2d(l_x_1_, x_4, l_x_3_);  l_x_1_ = x_4 = l_x_3_ = None
    return (x_5,)
```

Here the generated forward expects x_2 as its first argument, before x_1.

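As a result, I can call `graph_module[0](x_1, x_2, x_3)` without issue, but calling `graph_module_1[0](x_1, x_2, x_3)` raises an error because the inputs are bound to the wrong placeholders: x_1 is passed as L_x_2_ and x_2 as L_x_1_, so conv2d receives the weight tensor as its input and vice versa.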
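If you want to keep capturing graphs through torch.compile with a custom backend, one possible workaround is to wrap the captured graph in an adapter that accepts arguments in the original signature order and forwards them in graph order. This is a minimal sketch, not a PyTorch API: `dynamo_input_to_param` and `reorder_to_signature` are hypothetical helpers, and the sketch assumes Dynamo names each graph input `L_<arg>_`, as in the graphs printed above, and that every graph input is a positional argument of the original function. Any other input name (for example a module attribute or global) raises an error instead of being guessed.

```python
import inspect
from typing import Any, Callable, Sequence


def dynamo_input_to_param(placeholder: str) -> str:
    """Map a placeholder name such as 'L_x_2_' back to the argument name 'x_2'.

    Assumption: Dynamo names local-argument inputs 'L_<name>_'.
    """
    if placeholder.startswith("L_") and placeholder.endswith("_") and len(placeholder) > 3:
        return placeholder[2:-1]
    raise ValueError(f"unrecognised placeholder name: {placeholder!r}")


def reorder_to_signature(
    graph_fn: Callable[..., Any],
    placeholders: Sequence[str],
    original_fn: Callable[..., Any],
) -> Callable[..., Any]:
    """Wrap graph_fn so it can be called with original_fn's positional argument order."""
    params = list(inspect.signature(original_fn).parameters)
    # order[k] = index (in the original signature) of the k-th graph input
    order = [params.index(dynamo_input_to_param(p)) for p in placeholders]

    def wrapper(*args: Any) -> Any:
        return graph_fn(*(args[i] for i in order))

    return wrapper
```

With a real capture, the placeholder names can be read from the graph with `[n.target for n in gm.graph.nodes if n.op == "placeholder"]` and passed in, e.g. `reorder_to_signature(graph_module_1[0], names, model_1)(x_1, x_2, x_3)`. An argument that never becomes a graph input is simply not forwarded.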

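This is indeed intended behaviour: as the example shows, Dynamo creates graph inputs in the order the arguments are first accessed during tracing, which is why model_1's graph takes x_2 first. If you need a graph whose inputs follow the model's signature, use torch.export, which will maintain the correct input mapping.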