Hi, I am trying out the JIT trace functions to parse the model structure automatically.
Originally, I used model_output.grad_fn to parse the structure, but the traced graph seems simpler.
FYI, I am using PyTorch 0.3.1 (installed with pip) with Python 2.7.
Below is part of the code I wrote:
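```python
import numpy as np
import torch
import torch.onnx
from torch.autograd import Variable

input_var = Variable(torch.FloatTensor(np.random.random((1, init_shape[0], init_shape[1], init_shape[2])).astype(np.float32))).cuda()  # random 1 x C x H x W input
trace, out = torch.jit.trace(model, (input_var,))  # in 0.3.x this returns (trace, output)
torch.onnx._optimize_trace(trace)  # private (underscore-prefixed) ONNX helper
graph = trace.graph()
```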
Once I have the graph, I can access its nodes and inputs (I assume the inputs hold the weights, plus the input tensor itself).
I want to access the weights inside the nodes, but I cannot find any documentation on how to get the weight values back into Python.
One workaround is to parse the nodes and retrieve the weights from model.state_dict(), but I think there should be a simpler way.
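The state_dict workaround can be sketched in plain Python. This is a minimal sketch that assumes (not confirmed here) the PyTorch 0.3.x tracer lays out the graph inputs as the model's own inputs first, followed by one input per `model.state_dict()` entry (parameters and buffers) in the same order. The helper name `pair_graph_inputs_with_state_dict` and its `num_user_inputs` parameter are hypothetical, and the code only pairs items by position, so it checks lengths rather than trusting the assumption blindly.

```python
from collections import OrderedDict


def pair_graph_inputs_with_state_dict(graph_inputs, state_dict, num_user_inputs=1):
    """Hypothetical helper: pair traced-graph parameter inputs with state_dict entries.

    ASSUMPTION: graph inputs = user inputs first, then one input per
    state_dict entry in state_dict order. Verify for your PyTorch version.
    """
    graph_inputs = list(graph_inputs)
    # Skip the model's own inputs (e.g. input_var); the rest should be weights.
    param_inputs = graph_inputs[num_user_inputs:]
    if len(param_inputs) != len(state_dict):
        raise ValueError(
            "expected %d parameter inputs, got %d"
            % (len(state_dict), len(param_inputs))
        )
    # Pair by position: i-th parameter input <-> i-th state_dict entry.
    # Keyed by state_dict name, since graph input objects may not be hashable.
    pairs = OrderedDict()
    for graph_input, (name, tensor) in zip(param_inputs, state_dict.items()):
        pairs[name] = (graph_input, tensor)
    return pairs
```

With the trace from the snippet above, a call such as `pair_graph_inputs_with_state_dict(graph.inputs(), model.state_dict())` would, if the ordering assumption holds, map each state_dict name to its graph input and weight tensor. A length mismatch raises ValueError, which is a hint that the assumed layout does not match your version.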