I want to implement a C++ frontend that converts PyTorch IR to my own IR. In torch/csrc/jit/ir.h, I found the class torch::jit::Graph. I can already load a script::Module from a model file.
Is there any function to translate a jit::script::Module into a jit::Graph?
I read the file jit/script/module.h; maybe I can use module.get_method("forward").graph()?
Does anybody know how to do this?
You can do something like this:
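import torch

def foo(x, y):
    return 2 * x + y

# Trace foo with example inputs to record its operations as a graph
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# A traced function can also be called from a scripted function
@torch.jit.script
def bar(x):
    return traced_foo(x, x)

# Print the IR graph of the traced function
print(traced_foo.graph)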
and I got a graph like this:
graph(%x : Float(3)
      %y : Float(3)) {
  %2 : Long() = prim::Constant[value={2}]()
  %3 : Float(3) = aten::mul(%x, %2)
  %4 : int = prim::Constant[value=1]()
  %5 : Float(3) = aten::add(%3, %y, %4)
  return (%5);
}
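For the original goal of converting this IR into your own, one possible approach is to walk the graph node by node instead of parsing the printed text. Below is a minimal sketch of that idea, not a definitive implementation. MyOp and to_my_ir are hypothetical names made up for illustration. The sketch assumes the graph object exposes inputs(), nodes(), and outputs(), that each node exposes kind(), inputs(), and outputs(), and that each value exposes debugName(), as the Python bindings do in recent PyTorch releases (older releases used uniqueName() instead).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MyOp:
    """Hypothetical node of a custom IR (not part of PyTorch)."""
    kind: str           # operator name, e.g. "aten::mul"
    inputs: List[str]   # names of the values this op consumes
    outputs: List[str]  # names of the values this op produces


def to_my_ir(graph) -> List[MyOp]:
    """Walk a TorchScript-style graph in order and emit one MyOp per node.

    Assumption: `graph` behaves like the object returned by
    `traced_foo.graph`, i.e. it has inputs(), nodes() and outputs().
    """
    # Graph inputs become pseudo-ops so later ops can refer to them by name.
    ops = [MyOp("input", [], [v.debugName()]) for v in graph.inputs()]

    # nodes() yields the graph's nodes in execution order.
    for node in graph.nodes():
        ops.append(MyOp(
            node.kind(),
            [v.debugName() for v in node.inputs()],
            [v.debugName() for v in node.outputs()],
        ))

    # Record which values the graph returns.
    ops.append(MyOp("return", [v.debugName() for v in graph.outputs()], []))
    return ops
```

With the traced function from the example above, calling to_my_ir(traced_foo.graph) should give one entry per graph input, one per node (the constants, aten::mul, and aten::add), and a final "return" entry. Attributes such as a constant's value are not captured here and would need extra handling.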
I'm just wondering what the IR graph is used for. Thanks for any clarification.