Running a torch::jit::Graph with the C++ API

Hey, I’m quite stuck trying to run a torch::jit::Graph with the C++ API. I have already searched for docs and example implementations, without success.
My main goal is to divide a JIT graph into subgraphs, controlling which nodes go into each subgraph. My plan is to parse the graph into nodes, build a new graph from each sequence of nodes (grouped according to my own criteria), and then run the graphs one by one.

This is what I have tried so far:

```cpp
#include <torch/script.h>
#include <iostream>

int main() {
    torch::jit::script::Module my_model = torch::jit::load("<MY_PATH>/model.pkl");

    // Graph of the scripted forward() method.
    torch::jit::script::Method m = my_model.get_method("forward");
    auto g = m.graph();

    // Wrap the graph in a standalone function owned by a new compilation unit.
    auto cu = std::make_shared<torch::jit::script::CompilationUnit>();
    c10::QualifiedName name("forward");
    torch::Function *fn = cu->create_function(std::move(name), g);

    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::rand({1, 1, 3, 3}));

    // This call throws the error shown below.
    auto output = fn->operator()(inputs);
    std::cout << output << std::endl;
}
```

The model.pkl file was produced by the following Python code:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

scripted_foo = torch.jit.script(Net())
scripted_foo.save('model.pkl')
```

This C++ approach produces the following error:

```
libc++abi.dylib: terminating with uncaught exception of type c10::Error: forward() Expected a value of type '__torch__.Net' for argument 'self' but instead found type 'Tensor'.
Position: 0
Declaration: forward(__torch__.Net self, Tensor x) -> (Tensor) (checkArg at ../aten/src/ATen/core/function_schema_inl.h:194)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x1038ce787 in libc10.dylib)
frame #1: c10::FunctionSchema::checkArg(c10::IValue const&, c10::Argument const&, c10::optional<unsigned long>) const + 719 (0x11285ee6f in libtorch.dylib)
frame #2: c10::FunctionSchema::checkAndNormalizeInputs(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&, std::__1::unordered_map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, c10::IValue, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const, c10::IValue> > > const&) const + 228 (0x11285ddf4 in libtorch.dylib)
frame #3: torch::jit::Function::operator()(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >, std::__1::unordered_map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, c10::IValue, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > >, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const, c10::IValue> > > const&) + 49 (0x11285d221 in libtorch.dylib)
frame #4: main + 686 (0x1037eb59e in deep_project)
frame #5: start + 1 (0x7fff700f0cc9 in libdyld.dylib)
frame #6: 0x0 + 1 (0x1 in ???)

Process finished with exit code 6
```

I assume I need to push self (of type __torch__.Net) onto the inputs somehow, but I don't know how. Any help would be appreciated.
Thanks a lot.


Can you post this under the ‘jit’ category?


For methods, as in Python, the first input to the graph is self, which represents the module object instance. The Module API takes care of adding self to the stack; see here.


Thanks a lot, it works! I have another question, though: do you know the cleanest way to create a graph from a handful of nodes?

Using the base Graph API for node creation and insertion is probably the cleanest approach. For an example of how to take “foreign” Node*s that are not owned by a graph and copy them into it, the Graph::copy method is useful: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/ir.cpp#L680
