Hi, I am trying out the JIT trace functions to automatically parse a model's structure.
Originally, I used model_output.grad_fn to parse the structure, but the traced graph structure seems simpler.
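For comparison, the grad_fn route can be sketched as a depth-first walk over each function's next_functions, which is a sequence of (function, index) pairs with None for inputs that need no gradient. This is a minimal sketch, not a PyTorch utility: walk_grad_fn and the _StubFn demo class are hypothetical names, and the demo uses stand-in objects so it runs without a model. On a real output you would call walk_grad_fn(model_output.grad_fn); leaf parameters typically show up as AccumulateGrad nodes whose variable attribute holds the tensor.

```python
def walk_grad_fn(root):
    """Collect every autograd function reachable from root, depth-first.

    Duck-typed: works on any object exposing next_functions as a sequence
    of (function, index) pairs, e.g. output.grad_fn. None entries are skipped.
    """
    visited = []
    seen = set()
    stack = [root]
    while stack:
        fn = stack.pop()
        if fn is None or id(fn) in seen:
            continue
        seen.add(id(fn))
        visited.append(fn)
        # Push children in reverse so they are visited in their listed order.
        for next_fn, _ in reversed(tuple(getattr(fn, "next_functions", ()))):
            stack.append(next_fn)
    return visited


class _StubFn(object):
    """Hypothetical stand-in for an autograd function (demo only)."""

    def __init__(self, name, children=()):
        self.name = name
        self.next_functions = tuple((child, 0) for child in children)


# Demo graph: add(mul(x, w), None, w), where w is reused.
leaf_w = _StubFn("AccumulateGrad_w")
leaf_x = _StubFn("AccumulateGrad_x")
mul = _StubFn("MulBackward", [leaf_x, leaf_w])
add = _StubFn("AddBackward", [mul, None, leaf_w])
order = [fn.name for fn in walk_grad_fn(add)]
print(order)
```

The seen set matters because a parameter used twice (leaf_w here) is reachable along several paths but should be reported once.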
FYI, I am using PyTorch 0.3.1 (installed with pip) on Python 2.7.
Below is part of the code I wrote:
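```python
import numpy as np
import torch
from torch.autograd import Variable
input_var = Variable(torch.FloatTensor(
    np.random.random((1, init_shape, init_shape, init_shape)).astype(np.float32))).cuda()
trace, out = torch.jit.trace(model, (input_var,))  # model, init_shape defined elsewhere
torch.onnx._optimize_trace(trace)
graph = trace.graph()  # the traced IR graph
```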
Once I have the graph, I can access its Nodes and Inputs (I assume the Inputs are the nodes holding the weights, plus the input tensor).
I want to access the weight values behind these nodes, but I cannot find any documentation on how to get them into Python.
One workaround would be to parse the Nodes and retrieve the weights from model.state_dict(), but I think there should be a simpler way.
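The state_dict workaround could look roughly like this, assuming (unverified for 0.3.1) that the traced graph lists the real model inputs first, followed by one input per state_dict entry in the same order. pair_inputs_with_weights is a hypothetical helper: it takes (name, shape) pairs for the graph inputs, however you extract them, and compares shapes so that a wrong ordering assumption fails loudly instead of silently mismatching weights.

```python
from collections import OrderedDict


def pair_inputs_with_weights(graph_inputs, num_real_inputs, state_dict, shape_of):
    """Map each parameter input of a traced graph to its state_dict entry.

    graph_inputs: list of (name, shape) pairs, one per graph input, in graph order.
    num_real_inputs: how many leading graph inputs are actual model inputs.
    state_dict: ordered mapping of key -> tensor; its order is ASSUMED to
        match the order of the remaining graph inputs.
    shape_of: callable returning a shape tuple for a state_dict value.
    """
    param_inputs = graph_inputs[num_real_inputs:]
    if len(param_inputs) != len(state_dict):
        raise ValueError("expected %d parameter inputs, got %d"
                         % (len(state_dict), len(param_inputs)))
    mapping = OrderedDict()
    for (input_name, input_shape), (key, value) in zip(param_inputs, state_dict.items()):
        # A shape mismatch means the ordering assumption does not hold.
        if tuple(input_shape) != tuple(shape_of(value)):
            raise ValueError("shape mismatch for %s: %r vs %r"
                             % (key, input_shape, shape_of(value)))
        mapping[input_name] = (key, value)
    return mapping


# Toy demo: shapes stand in for tensors (hypothetical names and shapes).
demo_inputs = [("0", (1, 3, 8, 8)), ("1", (16, 3, 3, 3)), ("2", (16,))]
demo_state = OrderedDict([("conv.weight", (16, 3, 3, 3)), ("conv.bias", (16,))])
mapping = pair_inputs_with_weights(demo_inputs, 1, demo_state, lambda shape: shape)
print(mapping)
```

With a real model you would pass model.state_dict() and shape_of=lambda t: tuple(t.size()); buffers such as BatchNorm running statistics are part of the state_dict too, so they are expected to line up with graph inputs as well.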