I have a simple PyTorch model with a few layers, and I want to trace its graph and export it as a set of nodes and edges (that is, as a traditional graph). I also want to export each layer’s hyper-parameters (such as a Conv2d layer’s number of input and output channels) as well as its parameters (such as a Conv2d layer’s kernels and bias).
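To make the second part of the goal concrete, here is a minimal sketch of the per-layer information I mean. The toy model and the `describe_layers` helper are hypothetical names made up for illustration, and only Conv2d hyper-parameters are handled; other layer types would need their own attribute lists.

```python
import torch
import torch.nn as nn

# Hypothetical toy model, standing in for the "simple model with a few layers".
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3),
)


def describe_layers(model):
    """Collect hyper-parameters and parameter tensors for each leaf module."""
    layers = {}
    for name, module in model.named_modules():
        # Skip containers (modules that have children of their own).
        if len(list(module.children())) > 0:
            continue
        info = {"type": type(module).__name__}
        # Assumption: only Conv2d hyper-parameters are listed in this sketch.
        if isinstance(module, nn.Conv2d):
            info["hyper_parameters"] = {
                "in_channels": module.in_channels,
                "out_channels": module.out_channels,
                "kernel_size": module.kernel_size,
                "stride": module.stride,
                "padding": module.padding,
            }
        # Copy this module's own parameters (e.g. Conv2d weight and bias).
        info["parameters"] = {
            p_name: p.detach().clone()
            for p_name, p in module.named_parameters(recurse=False)
        }
        layers[name] = info
    return layers


layers = describe_layers(model)
for name, info in layers.items():
    shapes = {k: tuple(v.shape) for k, v in info["parameters"].items()}
    print(name, info["type"], info.get("hyper_parameters"), shapes)
```

This covers the layers themselves but not how they are connected, which is the part I need a trace for.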
- Is it possible to do this using torch.jit.trace (or torch.jit._get_trace_graph)?
- Is there any way to trace the graph other than torch.jit (for more complex models such as ResNet or DenseNet)?
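For the nodes-and-edges part, the kind of output I am after looks roughly like the sketch below. It uses torch.fx.symbolic_trace as one candidate alternative to torch.jit; this is an assumption on my side, it requires a PyTorch version that ships torch.fx, and I have not confirmed that it handles every model (symbolic tracing can fail on data-dependent control flow). `TinyNet` and `export_graph` are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn
import torch.fx


class TinyNet(nn.Module):
    """Hypothetical model with a skip connection, so one node has two inputs."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.relu(self.conv(x))
        return y + x


def export_graph(model):
    """Return (nodes, edges) from a symbolic trace of the model."""
    gm = torch.fx.symbolic_trace(model)
    # One entry per graph node: its name, its kind, and what it calls.
    nodes = [
        {"name": n.name, "op": n.op, "target": str(n.target)}
        for n in gm.graph.nodes
    ]
    # One (source, destination) pair per data dependency between nodes.
    edges = [
        (src.name, n.name)
        for n in gm.graph.nodes
        for src in n.all_input_nodes
    ]
    return nodes, edges


nodes, edges = export_graph(TinyNet())
for node in nodes:
    print(node)
for edge in edges:
    print(edge)
```

For nodes whose `op` is `call_module`, the `target` string is the module's qualified name, so it could in principle be used to look the layer up in `dict(model.named_modules())` and attach its hyper-parameters and parameters to the graph node.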
Thank you!