How to parse torch jit trace output?

For my use case, I need to programmatically derive, for each layer, the layer's settings (e.g. stride and kernel size for a convolution) as well as the number of input and output tensors, so that I can build a graph representation of the network. I tried a simple Conv2d model, and it produces the following trace code when traced with:

```python
traced_cell = torch.jit.trace(net, (inputseq,), check_trace=False)

# Resulting trace code:
def forward(self,
    input: Tensor) -> Tensor:
  _0 = getattr(self, "0")
  weight = _0.weight
  _1 = torch._convolution(input, weight, _0.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True)
  return _1
```
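For reference, the positional arguments of the printed torch._convolution call follow the operator's schema. The sketch below pairs them with their names, assuming the 12-argument overload seen in this trace (newer PyTorch releases add a trailing allow_tf32 flag). The argument list is copied by hand from the trace rather than parsed. Note that the kernel size is not an argument at all; it is implied by the shape of weight (out_channels, in_channels / groups, kH, kW).

```python
# Parameter names of the 12-argument aten::_convolution overload, in positional order.
CONV_ARG_NAMES = [
    "input", "weight", "bias", "stride", "padding", "dilation",
    "transposed", "output_padding", "groups",
    "benchmark", "deterministic", "cudnn_enabled",
]

# Arguments copied by hand from the traced call above (tensor arguments kept as names).
traced_args = ["input", "weight", "_0.bias", [1, 1], [0, 0], [1, 1],
               False, [0, 0], 1, False, False, True]

# Pair each positional argument with its schema name.
settings = dict(zip(CONV_ARG_NAMES, traced_args))
for name, value in settings.items():
    print(f"{name:15s} {value}")
```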

This is a string output, so parsing it is too tedious. Is there a way to use traced_cell.graph to find out which layers are present and which tensors go into and out of each layer, so that I can make semantic connections between nodes on my end (say, by dumping the layers and tensors into a JSON file that represents the graph)? I only need to know the number and names of the tensors for each layer so that I can connect the layers in the network.
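Rather than parsing the printed code, the graph object can be walked node by node. A minimal sketch, assuming a reasonably recent PyTorch: there, traced_cell.graph keeps submodule calls as prim::CallMethod nodes, while inlined_graph flattens them into individual ops like the _convolution call above. The model and input shape are hypothetical stand-ins, since the original net isn't shown. node.kind(), node.inputs(), node.outputs() and Value.debugName() belong to the largely undocumented torch._C graph API, so details may change between releases.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for `net`: a Sequential whose submodule "0" is a Conv2d,
# which is what getattr(self, "0") in the trace suggests.
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=1))
inputseq = torch.randn(1, 3, 16, 16)
traced_cell = torch.jit.trace(net, (inputseq,), check_trace=False)

# inlined_graph expands submodule forward() calls into individual aten::/prim:: nodes.
graph = traced_cell.inlined_graph

def describe_nodes(graph):
    """Return (kind, input value names, output value names) for every node, in order."""
    rows = []
    for node in graph.nodes():
        inputs = [v.debugName() for v in node.inputs()]
        outputs = [v.debugName() for v in node.outputs()]
        rows.append((node.kind(), inputs, outputs))
    return rows

for kind, inputs, outputs in describe_nodes(graph):
    print(f"{kind:22s} in={inputs} out={outputs}")

# The first graph input is the module itself (self); the rest are forward()'s arguments.
print("graph inputs: ", [v.debugName() for v in graph.inputs()])
print("graph outputs:", [v.debugName() for v in graph.outputs()])
```

Values with the same debug name are the same tensor, so matching one node's output names against another node's input names gives the layer-to-layer connections.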

In the above example, input is both the argument of forward and the input to the convolution; this is the kind of information I need. Likewise, the output of the convolution, _1, is also the output of the model. I also need to know that this convolution uses a kernel size of 1, and so on.
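A sketch of turning the inlined graph into a JSON-friendly description: one entry per op, listing its incoming activation tensors, its outputs, and its arguments with constants such as stride and padding resolved to Python values. It assumes a hypothetical Sequential model (a ReLU is added so there is more than one layer), treats prim::Constant, prim::ListConstruct and prim::GetAttr as bookkeeping rather than layers, and uses Value.toIValue() to read constants. These are illustrative choices, not an official export format.

```python
import json
import torch
import torch.nn as nn

# Hypothetical stand-in model; replace with your own `net` and `inputseq`.
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=1), nn.ReLU())
inputseq = torch.randn(1, 3, 16, 16)
traced_cell = torch.jit.trace(net, (inputseq,), check_trace=False)
graph = traced_cell.inlined_graph

# Nodes that hold parameters or argument values rather than computing activations.
BOOKKEEPING = {"prim::Constant", "prim::ListConstruct", "prim::GetAttr"}

def resolve(value):
    """Return a constant (or list of constants) as a Python value, else the value's debug name."""
    producer = value.node()
    if producer.kind() == "prim::Constant":
        return value.toIValue()
    if producer.kind() == "prim::ListConstruct":
        return [resolve(v) for v in producer.inputs()]
    return value.debugName()

def build_layer_graph(graph):
    # Skip the `self` input (first graph input); the rest are the model's inputs.
    model_inputs = [v.debugName() for v in graph.inputs()][1:]
    model_outputs = [v.debugName() for v in graph.outputs()]
    layers, produced = [], set(model_inputs)
    for node in graph.nodes():
        if node.kind() in BOOKKEEPING:
            continue
        layers.append({
            "op": node.kind(),
            # Tensors coming from the model input or from an earlier layer.
            "inputs": [v.debugName() for v in node.inputs() if v.debugName() in produced],
            "outputs": [v.debugName() for v in node.outputs()],
            # Every argument, with constants (stride, padding, ...) resolved to values.
            "args": [resolve(v) for v in node.inputs()],
        })
        produced.update(layers[-1]["outputs"])
    return {"inputs": model_inputs, "outputs": model_outputs, "layers": layers}

# default=str covers any constant type json cannot serialize directly.
print(json.dumps(build_layer_graph(graph), indent=2, default=str))
```

The kernel size still does not show up here, because _convolution only receives the weight tensor; it has to come from the module itself (e.g. net[0].kernel_size) or from the weight's shape.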

Any suggestions would be appreciated.


Any progress on this? There should be a way to get a graph in terms of modules and operations.
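One alternative worth trying, if your PyTorch version includes torch.fx and Module.get_submodule (roughly 1.9 onward), is symbolic tracing. It produces a graph at the level of modules and functions rather than ATen ops. Each call_module node names a submodule, so its settings (kernel_size, stride, ...) can be read directly from the module object. This is a swapped-in technique rather than torch.jit, and the model below is a hypothetical stand-in. Also note that symbolic tracing cannot handle control flow that depends on tensor values.

```python
import torch.nn as nn
from torch.fx import symbolic_trace

# Hypothetical stand-in model.
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=1), nn.ReLU())
gm = symbolic_trace(net)

layers = []
for node in gm.graph.nodes:
    entry = {
        "name": node.name,
        "op": node.op,                  # placeholder / call_module / call_function / output ...
        "target": str(node.target),
        "inputs": [n.name for n in node.all_input_nodes],  # nodes feeding this one
        "users": [n.name for n in node.users],              # nodes consuming its output
    }
    if node.op == "call_module":
        module = gm.get_submodule(node.target)
        entry["type"] = type(module).__name__
        if isinstance(module, nn.Conv2d):
            # Layer settings come straight from the module's attributes.
            entry["settings"] = {
                "kernel_size": module.kernel_size,
                "stride": module.stride,
                "padding": module.padding,
                "dilation": module.dilation,
                "groups": module.groups,
            }
    layers.append(entry)

for entry in layers:
    print(entry)
```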