I want to run a forward and a backward pass on a ScriptModule and then retrieve the input and output tensors of each node in the execution graph. Is this possible?
My code looks like this:
```python
import torch

# `model` and `args` are defined elsewhere
trace, out = torch.jit.get_trace_graph(model, args)
out.backward(gradient=torch.ones_like(out))
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
torch_graph = trace.graph()
for node in torch_graph.nodes():
    # How do I get the input and output tensors of this node?
    # Additionally, is it possible to get the parameters (weight/bias) of this node?
    pass
```
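For reference, here is a minimal sketch of one possible approach, assuming a recent PyTorch where `torch.jit.trace` is used instead of the older `get_trace_graph` call. The nodes of a traced graph only carry symbolic `Value` objects (debug names and types), not runtime tensors, so walking `node.inputs()`/`node.outputs()` shows the wiring but no data. To capture actual tensors, the sketch swaps in a different technique, forward hooks on the eager submodules, which works per module call rather than per graph node. `TinyNet`, `records`, `make_hook`, and `node_summaries` are hypothetical names for illustration only.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the `model` in the question."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = TinyNet()
example = torch.randn(5, 4)

# 1) Graph nodes are symbolic: inputs()/outputs() yield torch._C.Value
#    objects (debug names and types), not the tensors computed at runtime.
traced = torch.jit.trace(model, example)
node_summaries = []
for node in traced.inlined_graph.nodes():
    ins = [v.debugName() for v in node.inputs()]
    outs = [v.debugName() for v in node.outputs()]
    node_summaries.append((node.kind(), ins, outs))
    print(node.kind(), ins, "->", outs)

# 2) Swapped-in technique: forward hooks on the eager submodules capture the
#    real tensors, per module call rather than per graph node.
records = {}


def make_hook(name):
    def hook(module, inputs, output):
        output.retain_grad()  # keep the gradient w.r.t. this output after backward()
        records[name] = {"inputs": [t.detach() for t in inputs], "output": output}
    return hook


handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_children()]

out = model(example)
out.backward(gradient=torch.ones_like(out))
for h in handles:
    h.remove()  # remove the hooks once the tensors are recorded

for name, rec in records.items():
    print(name, tuple(rec["output"].shape), "grad:", rec["output"].grad is not None)

# 3) Parameters (weights/biases) live on the module, keyed by qualified name;
#    .grad is populated by the backward() call above.
for name, param in model.named_parameters():
    print(name, tuple(param.shape), param.grad is not None)
```

Because a hook fires on every call of its submodule, a module reused more than once in `forward` would overwrite its entry in `records`; keying by a call counter would avoid that if it matters.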