Visualizing the computation graph in C++ (minimal example)

A naive, minimal example of printing the autograd computation graph in LibTorch (C++) in Graphviz DOT format.


//
// Created by duane on 10/2/20.
//

#include <deque>
#include <iostream>
#include <memory>
#include <unordered_set>

#include <torch/torch.h>

int main(int arg, char *argv[]){

    torch::cuda::is_available();
    auto x = torch::randn(3, torch::requires_grad());
    auto z = torch::randn(3, torch::requires_grad());
    auto y = x * 2;
    y = y * 2;
    y = y * z;
    y += z;
    y = y * 2;

    std::cout << "digraph G {" << std::endl;

    std::deque<std::shared_ptr<torch::autograd::Node>> nodes;
    nodes.push_front(y.grad_fn());
    while (!nodes.empty()){
        std::cout << "\"" << nodes.back() << "\"" << " [label=\"" << nodes.back()->name() << "\"]" << std::endl;
        for (const auto &edge : nodes.back()->next_edges()) {
            nodes.push_front(edge.function);
            std::cout << "\"" << nodes.back() << "\"" << " -> " << "\"" << edge.function << "\"" << std::endl;
        }
        nodes.pop_back();
    }

    std::cout << "}" << std::endl;

}

Output

digraph G {
"0x557c2b12b230" [label="MulBackward1"]
"0x557c2b12b230" -> "0x557c2b12b080"
"0x557c2b12b080" [label="AddBackward0"]
"0x557c2b12b080" -> "0x557c2b12abd0"
"0x557c2b12b080" -> "0x557c2b12ae00"
"0x557c2b12abd0" [label="MulBackward0"]
"0x557c2b12abd0" -> "0x557c2b12a660"
"0x557c2b12abd0" -> "0x557c2b12ae00"
"0x557c2b12ae00" [label="torch::autograd::AccumulateGrad"]
"0x557c2b12a660" [label="MulBackward1"]
"0x557c2b12a660" -> "0x557c2a61f300"
"0x557c2b12ae00" [label="torch::autograd::AccumulateGrad"]
"0x557c2a61f300" [label="MulBackward1"]
"0x557c2a61f300" -> "0x557c29ff6f30"
"0x557c29ff6f30" [label="torch::autograd::AccumulateGrad"]
}

Paste the output into http://www.webgraphviz.com/ (or render it locally with Graphviz `dot`) to visualize it.

4 Likes