How to get the computation graph and pass it to C++?

Hi @Stonepia,

Thanks for the question!

First, I’d like to understand better what you intend to do here. Are you trying to calculate a quantity like FLOPs for the network? One way to do this is with the profiler (see the torch.autograd documentation for PyTorch 1.8.0). It may cover much of what you’re trying to achieve; however, its interface is in Python rather than C++.
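As a rough sketch of the profiler route, assuming a PyTorch version whose autograd profiler accepts `with_flops=True` (older releases may not have this flag), you can record one forward pass and sum the per-operator FLOP estimates. The small `nn.Linear` model stands in for your network, and the profiler only estimates FLOPs for the operators it has formulas for (matrix multiplies and convolutions, for example), so treat the total as an approximation.

```python
import torch
import torch.nn as nn

# Toy model standing in for the real network (an assumption for illustration)
model = nn.Linear(64, 32)
x = torch.randn(8, 64)

# Record one forward pass; with_flops asks the profiler to estimate FLOPs
# for the operators it knows how to count
with torch.autograd.profiler.profile(record_shapes=True, with_flops=True) as prof:
    with torch.no_grad():
        model(x)

# Sum estimates over all recorded operators; ops without a formula contribute 0
total_flops = sum((evt.flops or 0) for evt in prof.key_averages())
print(f"estimated FLOPs: {total_flops}")
```

Because this stays entirely in Python, it avoids the C++ plumbing, at the cost of only covering operators the profiler can count.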

To answer your direct questions:

  1. I think the most out-of-the-box way to do this is to use TorchScript. For example, you can trace the model and pass it to a pybind11 function that takes a `torch::jit::Module`, like so:
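```cpp
// some_cpp_code.cpp
#include <torch/extension.h>
#include <torch/script.h>
#include <cassert>

  <...>
  m.def("_test_fn", [](const torch::jit::Module& mod) {
    // Look up the compiled forward method and walk its graph
    auto forward_method = mod.get_method("forward");
    assert(forward_method.graph());
    for (torch::jit::Node* n : forward_method.graph()->nodes()) {
      n->dump();  // print each node
    }
  });
```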
```python
# Python code for testing
import torch
import torchvision.models as models

# Hypothetical module name: replace `my_extension` with the name
# your C++ extension was built under
from my_extension import _test_fn

# Construct and trace ResNet18 for testing
rn18 = models.resnet18()
traced_rn18 = torch.jit.trace(rn18, (torch.randn(5, 3, 224, 224),))

# Inline method calls. This allows us to more easily traverse the graph
# of operations
torch._C._jit_pass_inline(traced_rn18.graph)
# Pass our module into the custom C++ function. Note that we use `._c` to
# retrieve the C++ version, not the Python version
_test_fn(traced_rn18._c)
```

Running that, you will see print-outs of all the nodes in the model.

  2. When using tracing, the shapes of values are recorded, and you can inspect them by calling `type()` on each node's output values. Note, however, that these shapes are specific to the invocation (the example inputs) that was used to record the graph.
  3. torch.fx is exclusively Python. You could define a C++ structure isomorphic to the FX graph, but that will probably be more work. We don’t have plans to extend FX to C++ at this point.
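To illustrate the point about recorded shapes, here is a minimal sketch that inspects a traced graph from Python instead of C++. It assumes that `Value.type()` returns a `torch._C.TensorType` for tensor values and that `TensorType.sizes()` gives the concrete sizes as a list (or `None` when they are unknown); the small `nn.Sequential` model and the batch size of 2 are arbitrary choices. Tracing again with a different batch size would record different sizes, which is exactly the invocation-specific caveat.

```python
import torch
import torch.nn as nn

# Small stand-in model; any traceable module would do
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
traced = torch.jit.trace(model, (torch.randn(2, 4),))

# Inline submodule calls so every aten:: op appears in one flat graph
torch._C._jit_pass_inline(traced.graph)

recorded = []  # (node kind, output sizes) pairs
for node in traced.graph.nodes():
    for out in node.outputs():
        ty = out.type()
        # Only tensor values carry shapes; skip module handles, ints, etc.
        if isinstance(ty, torch._C.TensorType):
            recorded.append((node.kind(), ty.sizes()))

for kind, sizes in recorded:
    print(kind, sizes)
```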

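On the FX side, one hedged way to approximate a C++ mirror of the graph is to flatten it in Python into plain data (op kind, name, target, and input names) that a C++ structure could then be populated from. The sketch below only covers the Python half; `export_fx_graph` is a hypothetical helper name, and the C++ struct it would feed, along with how the data crosses the binding, is left open.

```python
import torch
import torch.fx
import torch.nn as nn


def export_fx_graph(module: nn.Module) -> list:
    """Hypothetical helper: flatten an FX graph into plain Python data."""
    gm = torch.fx.symbolic_trace(module)
    records = []
    for node in gm.graph.nodes:
        records.append({
            "op": node.op,                # placeholder, call_module, output, ...
            "name": node.name,
            "target": str(node.target),   # module path, function, or method name
            "inputs": [n.name for n in node.all_input_nodes],
        })
    return records


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
records = export_fx_graph(model)
for r in records:
    print(r)
```

Each record refers to its inputs by name only, so the C++ side would need to rebuild edges by looking names up; that is part of the extra work mentioned above.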
James