It seems that the order of the arguments in an FX graph depends on the order in which the arguments are used during execution, rather than on the order in which they are declared in the model definition.
Is this expected behaviour? If so, how should I obtain my FX graph without running into it, or how could I work around it?
Here is an example:
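```python
import torch
import torch._dynamo


def model(x_1, x_2, x_3):
    x_5 = torch.nn.functional.conv2d(x_1, x_2, x_3)
    return x_5


def model_1(x_1, x_2, x_3):
    x_4 = x_2  # x_2 is referenced before x_1
    x_5 = torch.nn.functional.conv2d(x_1, x_4, x_3)
    return x_5


x_1 = torch.randn([4, 6, 8, 4])  # input:  (N, C_in, H, W)
x_2 = torch.randn([4, 6, 2, 2])  # weight: (C_out, C_in, kH, kW)
x_3 = torch.randn([4])           # bias:   (C_out,)


def get_fx(model, args):
    """Compile `model` with a custom backend and return the captured FX graphs."""
    graphs = []

    def some_backend(graph_module, sample_inputs):
        graph_module.print_readable()
        graphs.append(graph_module)
        return graph_module  # run the captured graph unchanged

    torch._dynamo.reset()  # clear Dynamo's caches so each model is traced fresh
    torch.compile(model, backend=some_backend)(*args)
    return graphs


graph_module = get_fx(model, [x_1, x_2, x_3])
output = graph_module[0](x_1, x_2, x_3)  # works

graph_module_1 = get_fx(model_1, [x_1, x_2, x_3])
output_1 = graph_module_1[0](x_1, x_2, x_3)  # fails: this graph expects (x_2, x_1, x_3)
```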
In the code above there are two models. The only difference is that `model_1` references the argument `x_2` (via `x_4 = x_2`) before it references `x_1`.
`get_fx()` compiles a model with a custom backend that prints each FX graph Dynamo captures and collects it in a list.
The printout for `graph_module` is:
```python
def forward(self, L_x_1_ : torch.Tensor, L_x_2_ : torch.Tensor, L_x_3_ : torch.Tensor):
    l_x_1_ = L_x_1_
    l_x_2_ = L_x_2_
    l_x_3_ = L_x_3_
    x_5 = torch.conv2d(l_x_1_, l_x_2_, l_x_3_); l_x_1_ = l_x_2_ = l_x_3_ = None
    return (x_5,)
```
which is expected.
The printout for `graph_module_1` is:
```python
def forward(self, L_x_2_ : torch.Tensor, L_x_1_ : torch.Tensor, L_x_3_ : torch.Tensor):
    x_4 = L_x_2_
    l_x_1_ = L_x_1_
    l_x_3_ = L_x_3_
    x_5 = torch.conv2d(l_x_1_, x_4, l_x_3_); l_x_1_ = x_4 = l_x_3_ = None
    return (x_5,)
```
Here the graph expects `x_2` before `x_1`, presumably because `x_2` is the first argument the function body touches.
As a result, calling `graph_module[0](x_1, x_2, x_3)` works, but calling `graph_module_1[0](x_1, x_2, x_3)` raises an error, because the graph receives its inputs in the wrong order.
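One possible workaround, sketched here under assumptions rather than as an official fix: read the input order the captured graph expects from its placeholder nodes (FX nodes with `op == "placeholder"`, whose `target` is the name used in the generated `forward` signature), map Dynamo's `L_<name>_` names back to the model's parameter names via `inspect.signature`, and reorder the arguments before calling the graph. The helper names (`placeholder_targets`, `source_name`, `reorder_for_graph`, `reorder_args`) are hypothetical, and the `L_<name>_` convention is inferred only from the printouts above; it may not hold for non-local inputs (globals, module attributes) or across PyTorch versions.

```python
import inspect


def placeholder_targets(graph_module):
    """Names of the graph's inputs, in the order its forward() expects them."""
    return [n.target for n in graph_module.graph.nodes if n.op == "placeholder"]


def source_name(placeholder):
    """Map a Dynamo placeholder name such as 'L_x_2_' back to 'x_2'.

    Assumption: plain local arguments are named 'L_<name>_' (as in the
    printouts above); anything else is rejected rather than guessed.
    """
    if placeholder.startswith("L_") and placeholder.endswith("_") and len(placeholder) > 3:
        return placeholder[2:-1]
    raise ValueError(f"unrecognised placeholder name: {placeholder!r}")


def reorder_for_graph(placeholders, declared_names, args):
    """Reorder args given in declared order to match the placeholder order."""
    by_name = dict(zip(declared_names, args))
    return [by_name[source_name(p)] for p in placeholders]


def reorder_args(graph_module, model, args):
    """Hypothetical helper: reorder args for graph_module using model's signature."""
    declared = list(inspect.signature(model).parameters)
    return reorder_for_graph(placeholder_targets(graph_module), declared, args)
```

With the example above, `graph_module_1[0](*reorder_args(graph_module_1[0], model_1, [x_1, x_2, x_3]))` should pass the tensors as `x_2, x_1, x_3`, matching the generated `forward`. Because it looks arguments up by name, the sketch also drops any declared argument that the graph never uses.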