Handling LSTM states as model's inputs/outputs using fx.symbolic_trace

I want to modify a simple model having lstm layers in a way that it receives the states as additional inputs and returns updated states as additional outputs. Indeed, it is required to run model in realtime mode for each new timestep. Generally, I don’t have access to the model’s source code to simply create a new model, also because of other reasons mainly related to operations like datatype changement, view, reshape, etc., I’m obliged to use the graph itself and try to change it accordingly. For a better understanding, here is a basic example:

import torch
import torch.nn as nn
import operator
from torch.fx import symbolic_trace, GraphModule, Node

# Define the original model (since you cannot access the code, pretend it's given like this)
class Test_Model(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Test_Model, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=1, batch_first=True)

    def forward(self, x):
        x, _ = self.lstm(x)  # Original model discards the hidden and cell states
        return x

# Instantiate the model and input
batch_size = 2
seq_len = 5
input_size = 8
hidden_size = 16
model = Test_Model(input_size=input_size, hidden_size=hidden_size)

# Trace the model to create a symbolic representation of the computation graph
traced_graph = symbolic_trace(model)

# Modify the graph to add additional hidden and cell state inputs and outputs
graph = traced_graph.graph


#===================
modified_model = ...
#===================




# Test the modified model
input_tensor = torch.randn(batch_size, seq_len, input_size)  # Example input tensor
h_0 = torch.zeros(1, batch_size, hidden_size)  # (num_layers, batch, hidden_size)
c_0 = torch.zeros(1, batch_size, hidden_size)

# Run one timestep of inference
output, h_n, c_n = modified_model(input_tensor[:, 0:1, :], h_0, c_0)  # Pass single timestep input

In fact, I tried to implement codes like below but it dose not work and having issues with new outputs …

for node in graph.nodes:
    if node.op == "call_module" and isinstance(traced_graph.get_submodule(node.target), nn.LSTM):

        # Modify the LSTM node to accept hidden and cell states as additional inputs
        # Create new placeholder nodes for hidden and cell states
        h_0_node = graph.placeholder("h_0")
        c_0_node = graph.placeholder("c_0")

        # Adjust the arguments of the LSTM node to include h_0 and c_0
        # `node.args` originally contains just (input,)
        node.args = (node.args[0], (h_0_node, c_0_node))

        # Insert hidden state placeholders 
        with graph.inserting_before(node):
          h_node = graph.create_node('placeholder', target="h_0")
          c_node = graph.create_node('placeholder', target="c_0")
          node.args = (node.args[0], (h_node, c_node))

        node.meta['outputs'] = 2  # Set outputs to tuple of output and (new_h, new_c)

        # Adjust outputs to include the hidden states
        with graph.inserting_after(node):
            output_node = graph.call_function(operator.getitem, (node, 0))  # output sequence
            h_n_node = graph.call_function(operator.getitem, (node, 1, 0))  # h_n state
            c_n_node = graph.call_function(operator.getitem, (node, 1, 1))  # c_n state

        # Replace the original output with the separated outputs
        node.replace_all_uses_with(output_node)
        graph.output((output_node, h_n_node, c_n_node))  # Add h_n and c_n as additional outputs

# Recompile the graph into a new `GraphModule`
graph.lint()  # Ensures graph correctness
modified_model = GraphModule(traced_graph, graph)

Thanks in advance for your help.