A naive, minimal example of printing the autograd computation graph in LibTorch (C++) as a Graphviz DOT digraph.
//
// Created by duane on 10/2/20.
//
#include <deque>
#include <iostream>
#include <memory>
#include <ostream>
#include <unordered_set>

#include <torch/torch.h>
int main(int arg, char *argv[]){
torch::cuda::is_available();
auto x = torch::randn(3, torch::requires_grad());
auto z = torch::randn(3, torch::requires_grad());
auto y = x * 2;
y = y * 2;
y = y * z;
y += z;
y = y * 2;
std::cout << "digraph G {" << std::endl;
std::deque<std::shared_ptr<torch::autograd::Node>> nodes;
nodes.push_front(y.grad_fn());
while (!nodes.empty()){
std::cout << "\"" << nodes.back() << "\"" << " [label=\"" << nodes.back()->name() << "\"]" << std::endl;
for (const auto &edge : nodes.back()->next_edges()) {
nodes.push_front(edge.function);
std::cout << "\"" << nodes.back() << "\"" << " -> " << "\"" << edge.function << "\"" << std::endl;
}
nodes.pop_back();
}
std::cout << "}" << std::endl;
}
Output
digraph G {
"0x557c2b12b230" [label="MulBackward1"]
"0x557c2b12b230" -> "0x557c2b12b080"
"0x557c2b12b080" [label="AddBackward0"]
"0x557c2b12b080" -> "0x557c2b12abd0"
"0x557c2b12b080" -> "0x557c2b12ae00"
"0x557c2b12abd0" [label="MulBackward0"]
"0x557c2b12abd0" -> "0x557c2b12a660"
"0x557c2b12abd0" -> "0x557c2b12ae00"
"0x557c2b12ae00" [label="torch::autograd::AccumulateGrad"]
"0x557c2b12a660" [label="MulBackward1"]
"0x557c2b12a660" -> "0x557c2a61f300"
"0x557c2b12ae00" [label="torch::autograd::AccumulateGrad"]
"0x557c2a61f300" [label="MulBackward1"]
"0x557c2a61f300" -> "0x557c29ff6f30"
"0x557c29ff6f30" [label="torch::autograd::AccumulateGrad"]
}
Paste the printed DOT text into http://www.webgraphviz.com/ to visualize the graph.