Getting input and output tensors for each node in trace graph

I want to run forward and backward passes on a ScriptModule and then retrieve the input and output tensors of each node in the execution graph. Is this possible?

I have code like the following:

    import torch

    trace, out = torch.jit.get_trace_graph(model, args)
    torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
    torch_graph = trace.graph()

    for node in torch_graph.nodes():
        # How do I get the input and output tensors for this node?
        # Additionally, is it possible to get the parameters (weight/bias) for this node?
        pass

Could you explain a bit about what you’re trying to accomplish? Generally we don’t recommend that people use the Python IR bindings, as the API is not stable and you can get into some C++-y trouble. Thanks!
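If the underlying goal is to capture each layer's inputs, outputs, parameters, and gradients during a forward and backward pass, one stable alternative to walking the trace graph is PyTorch's module hooks. This is a minimal sketch of a swapped-in technique, not the trace-graph API from the question: it records tensors per leaf `nn.Module` rather than per IR node, it assumes an ordinary eager `nn.Module` rather than a ScriptModule, and the toy model and helper names (`records`, `make_forward_hook`, `make_backward_hook`) are hypothetical.

    import torch
    import torch.nn as nn

    # Hypothetical toy model standing in for the question's `model`.
    model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

    records = {}  # hypothetical store: module name -> captured tensors

    def make_forward_hook(name):
        def hook(module, inputs, output):
            entry = records.setdefault(name, {})
            # `inputs` is a tuple of positional args; detach so stored copies
            # do not keep the autograd graph alive.
            entry["inputs"] = [t.detach() for t in inputs]
            entry["output"] = output.detach()
            # Parameters owned directly by this module (e.g. weight/bias of nn.Linear).
            entry["params"] = {
                p_name: p.detach()
                for p_name, p in module.named_parameters(recurse=False)
            }
        return hook

    def make_backward_hook(name):
        def hook(module, grad_input, grad_output):
            entry = records.setdefault(name, {})
            # Entries can be None for inputs that do not require grad.
            entry["grad_input"] = list(grad_input)
            entry["grad_output"] = list(grad_output)
        return hook

    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # hook leaf modules only
            handles.append(module.register_forward_hook(make_forward_hook(name)))
            handles.append(module.register_full_backward_hook(make_backward_hook(name)))

    x = torch.randn(5, 4, requires_grad=True)
    model(x).sum().backward()

    # Remove the hooks so later calls are not recorded.
    for h in handles:
        h.remove()

After `backward()`, `records["0"]` holds the first `Linear` layer's input, output, `weight` and `bias`, and the gradients flowing into and out of it. Hooking only leaf modules keeps container modules such as `nn.Sequential` from duplicating their children's records.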