How to get the computation graph and pass it to C++?

I want to estimate a cost benchmark for every op in PyTorch, with the estimation done in C++. I plan to write something like the following:

# Python
graph = somefunction(model)  # <---- get the computation graph
total_graph_cost = predictCost(graph)

// C++ predict function, exposed to Python via pybind
Cost PredictFunction(Graph graph) {}
  1. Is there any method in PyTorch to get the computation graph and pass it to C++? (A purely C++ solution would also be acceptable.)

  2. How can I access the nodes in the graph? For example, how can I extract a conv2d node's input tensor shape?

  3. I know that TorchScript and torch.fx can both produce a graph, but I have a few questions:

  • These two methods produce two different IRs. How can I access each IR from C++? torch.fx, for example, is accessed entirely from Python; is there C++ code available to access this IR?
  • How can I access the underlying data structures from C++, given that the IR may be implemented in Python (as in the torch.fx case)?

Hi @Stonepia,

Thanks for the question!

First, I’d like to understand better what you intend to do here. Are you trying to calculate a quantity like FLOPs for the network? One way to do this is with the profiler: Automatic differentiation package - torch.autograd — PyTorch 1.8.0 documentation. It seems to do much of what you’re trying to achieve; however, its interface is in Python rather than C++.
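For reference, a minimal sketch of the profiler approach; the model and input here are placeholders, not anything from your setup:

```python
import torch

# A small stand-in model; any nn.Module works the same way
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
x = torch.randn(2, 8)

# Profile one forward pass; the profiler records every dispatched op
with torch.autograd.profiler.profile() as prof:
    model(x)

# Aggregate timings per op kind (e.g. aten::addmm, aten::relu)
print(prof.key_averages().table(sort_by="cpu_time_total"))
```

Note that this measures actual execution time of each op, so the model really runs during profiling.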

To answer your direct questions:

  1. I think the most out-of-the-box way to do this is to use TorchScript. For example, you can use tracing together with a pybind11 API that takes a torch::jit::Module, like so:
// some_cpp_code.cpp
m.def("_test_fn", [](const torch::jit::Module& mod) {
  auto forward_method = mod.get_method("forward");
  for (Node* n : forward_method.graph()->nodes()) {
    std::cout << *n;  // print each node in the graph
  }
});
# Python code for testing
import torch
import torchvision.models as models

# Construct and trace ResNet18 for testing
rn18 = models.resnet18()
traced_rn18 = torch.jit.trace(rn18, (torch.randn(5, 3, 224, 224),))

# Inline method calls. This allows us to more easily traverse the graph
# of operations
torch._C._jit_pass_inline(traced_rn18.graph)

# Pass our module into the custom C++ function. Note that we use `._c` to
# retrieve the C++ version, not the Python version. (The extension module
# name here depends on how the C++ code above is built.)
some_cpp_code._test_fn(traced_rn18._c)

Running that, you will see print-outs of all the nodes in the model.

  2. When using tracing, the shapes of values are recorded; you can inspect them by calling type() on each node's output Value. However, note that these shapes are specific to the example inputs that were used to record the graph
  3. torch.fx is exclusively Python. You could define a C++ data structure isomorphic to the FX graph, but it will probably be more work. We don’t have plans to generalize FX to C++ at this point
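To illustrate the point about recorded shapes, here is a small sketch using a toy conv module (not your model) that reads the tensor types tracing attached to each Value:

```python
import torch

# Toy conv module standing in for any traced network
conv = torch.nn.Conv2d(3, 8, kernel_size=3)
traced = torch.jit.trace(conv, torch.randn(1, 3, 16, 16))

# Each node output is a Value; tracing annotates Values with tensor types
for n in traced.graph.nodes():
    for out in n.outputs():
        t = out.type()
        # Only TensorTypes carry shape information
        if isinstance(t, torch._C.TensorType):
            print(n.kind(), t.sizes())
```

Again, the printed sizes reflect the specific example input passed to torch.jit.trace.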

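For comparison, the FX graph from point 3 is walked entirely in Python; a toy, purely illustrative module:

```python
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1.0)

# symbolic_trace produces an FX GraphModule whose graph is a Python object
gm = torch.fx.symbolic_trace(M())

# Each node records its opcode (placeholder, call_function, output, ...)
# and its target (the function or attribute being invoked)
for node in gm.graph.nodes:
    print(node.op, node.target)
```

This is convenient from Python, but, as noted above, there is no C++ counterpart to this data structure.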

@James_Reed Thank you very much for your answer, you saved my life!

Yes, I want to do something similar to the profiler. However, the profiler requires actually running the model, while I just want a rough prediction of the cost based on each op's input/output sizes (much as TensorFlow's Grappler does). So using torch's profiler alone won't work.

I want to implement a custom cost estimator for every op, and then accumulate the per-op costs into a total model cost.
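A rough Python sketch of that idea over a traced TorchScript graph; the cost table and default value below are made-up placeholders, not real measurements:

```python
import torch

# Hypothetical per-op cost table; the numbers are placeholders for
# whatever the real estimator would compute from input/output sizes
OP_COST = {
    "aten::conv2d": 10.0,
    "aten::relu": 1.0,
}
DEFAULT_COST = 0.5  # fallback for ops not in the table

def predict_cost(graph):
    # Sum an estimated cost over every node in the graph
    return sum(OP_COST.get(n.kind(), DEFAULT_COST) for n in graph.nodes())

traced = torch.jit.trace(torch.nn.ReLU(), torch.randn(4))
print(predict_cost(traced.graph))
```

The same loop translates directly to C++ over `forward_method.graph()->nodes()`, as in the snippet below.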

With that in mind, I want to get the graph and use code like your suggestion:

Cost totalcost;
for (Node* n : forward_method.graph()->nodes()) {
     totalcost += predictCost(n);
}

Thank you again, your reply really helps!