How to get positional order of inputs and outputs for a graph that was exported via torch.export?

I am trying to find the positional order of “placeholder” and “output” nodes of a torch FX graph generated by torch.export. Here is an example:

import torch
import torch.nn as nn
import torch.export

# author a model.
class MLPNet(nn.Module):
    """Two independent Linear+ReLU branches used to illustrate I/O ordering.

    ``forward(y, x)`` returns ``(relu(fc1(y)), relu(fc2(x)))``. The ``x``
    branch is computed first but returned second, so the computation order,
    the argument order and the return order all differ.
    """

    def __init__(self):
        super().__init__()
        # Each ReLU is registered exactly once. The original assigned
        # self.relu1 twice and never used self.relu2.
        self.fc1 = nn.Linear(32, 64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()

    def forward(self, y, x):
        # y: (..., 32) -> (..., 64) via fc1; x: (..., 64) -> (..., 32) via fc2.
        layer11 = self.relu2(self.fc2(x))  # computed first ...
        layer10 = self.relu1(self.fc1(y))
        return layer10, layer11  # ... but returned second

    def name(self):
        return "MLP"

model = MLPNet().eval()

# Generate random example inputs, sampled uniformly from [-1, 1).
n_batches = 100
x_in = torch.distributions.uniform.Uniform(-1, 1).sample([n_batches, 64, 32])
y_in = torch.distributions.uniform.Uniform(-1, 1).sample([n_batches, 64, 64])

# Export the module. The example args bind positionally to forward(self, y, x):
# x_in[0] (64x32) becomes `y` (fed to fc1: 32 -> 64) and y_in[0] (64x64)
# becomes `x` (fed to fc2: 64 -> 32).
m_export = torch.export.export(model, (x_in[0, :], y_in[0, :]))

# Print the graph of the unlifted module, where parameters appear as get_attr
# nodes. Note: print_tabular() needs the third-party `tabulate` package.
m_export.module().graph.print_tabular()

Printing the exported module's graph with `m_export.module().graph.print_tabular()` gives the following:

opcode         name        target              args                kwargs
-------------  ----------  ------------------  ------------------  --------
get_attr       fc2_weight  fc2.weight          ()                  {}
get_attr       fc2_bias    fc2.bias            ()                  {}
get_attr       fc1_weight  fc1.weight          ()                  {}
get_attr       fc1_bias    fc1.bias            ()                  {}
placeholder    y           y                   ()                  {}
placeholder    x           x                   ()                  {}
call_function  t           aten.t.default      (fc2_weight,)       {}
call_function  addmm       aten.addmm.default  (fc2_bias, x, t)    {}
call_function  relu        aten.relu.default   (addmm,)            {}
call_function  t_1         aten.t.default      (fc1_weight,)       {}
call_function  addmm_1     aten.addmm.default  (fc1_bias, y, t_1)  {}
call_function  relu_1      aten.relu.default   (addmm_1,)          {}
output         output_1    output              ((relu_1, relu),)   {}

The input nodes are (y, x), in that order, and the output nodes are (relu_1, relu), in that order.

I am following this approach (using torch.export.graph_signature) to make an association with the input argument order (as in the original module signature), and the corresponding “placeholder” and “output” nodes.

# Get the FX nodes of the exported program's user inputs and user outputs,
# in positional order.
#
# Nodes are matched to specs by *position*, not by name or opcode. The graph
# used is m_export.graph, the graph that graph_signature describes, rather
# than m_export.module().graph. ExportGraphSignature documents the ordering:
#   placeholders        == [*params/buffers/constants, *flattened_user_inputs]
#   output node args[0] == [*mutated buffers/inputs,   *flattened_user_outputs]
# in exactly the order of input_specs / output_specs.
#
# Positional pairing keeps working in cases where searching call_function
# nodes by name breaks: an output that is itself a placeholder or get_attr
# (e.g. `return x`), or the same node returned twice.
#
# "Positional" here means pytree-flattened order: a tuple/list/dict argument
# contributes one placeholder per leaf.
from torch.export.graph_signature import InputKind, OutputKind

signature = m_export.graph_signature
graph = m_export.graph

placeholders = [n for n in graph.nodes if n.op == 'placeholder']
if len(placeholders) != len(signature.input_specs):
    raise RuntimeError('graph placeholders do not line up with input_specs')
in_nodes = [
    node
    for node, spec in zip(placeholders, signature.input_specs)
    if spec.kind == InputKind.USER_INPUT
]
for idx, node in enumerate(in_nodes):
    print(f'Node for Input Argument #{idx}: {node}')

# An FX graph has a single output node; args[0] holds the returned values.
output_node = next(n for n in graph.nodes if n.op == 'output')
out_values = output_node.args[0]
if len(out_values) != len(signature.output_specs):
    raise RuntimeError('graph outputs do not line up with output_specs')
out_nodes = [
    value
    for value, spec in zip(out_values, signature.output_specs)
    if spec.kind == OutputKind.USER_OUTPUT
]
for idx, node in enumerate(out_nodes):
    print(f'Node for Output Argument #{idx}: {node}')

This prints the following, which seems correct at first glance:

Node for Input Argument #0: y
Node for Input Argument #1: x
Node for Output Argument #0: relu_1
Node for Output Argument #1: relu

My question is whether this approach is reliable for all types of graphs, and/or whether there is a simpler approach to accomplish the same thing.

Thank You!