How to get graph IR from jit::script::module

I want to implement a C++ frontend that converts PyTorch IR into my own IR. In torch/csrc/jit/ir.h I found the class torch::jit::Graph, and I can already load a script::module from a model file.
Is there a function to convert a jit::script::module into a jit::Graph?

Reading jit/script/module.h, I think I might be able to use module.get_method("forward").graph().
Does anybody know how to do this?
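A Python-side sketch of the same idea, assuming a PyTorch version where torch.jit.script accepts an nn.Module: a small module is scripted, saved and loaded back as a stand-in for a model file, and the graph of its forward method is read from the loaded module. The AddMul class and the file name are hypothetical, and the C++ accessor mentioned above is not exercised here.

```python
import os
import tempfile

import torch


class AddMul(torch.nn.Module):
    # Hypothetical module standing in for a real model file.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return 2 * x + y


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "add_mul.pt")
    # Compile with TorchScript and save, as a model file would be produced.
    torch.jit.script(AddMul()).save(path)
    # Load it back: the Python counterpart of loading a script module in C++.
    loaded = torch.jit.load(path)

# Graph IR of the "forward" method of the loaded module.
forward_graph = loaded.forward.graph
print(forward_graph)
```

On a module, loaded.graph is a shorthand for the same forward graph. Unlike a traced free function, the printed graph also has a %self input representing the module itself.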

You can do something like this:

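import torch

# An ordinary Python function to be traced.
def foo(x, y):
    return 2 * x + y

# Tracing runs foo on example inputs and records the executed ops as a graph.
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# Scripted functions can call traced ones; bar gets its own graph (bar.graph).
@torch.jit.script
def bar(x):
    return traced_foo(x, x)

print(traced_foo.graph)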

and I got a graph like this:

graph(%x : Float(3)
      %y : Float(3)) {
  %2 : Long() = prim::Constant[value={2}]()
  %3 : Float(3) = aten::mul(%x, %2)
  %4 : int = prim::Constant[value=1]()
  %5 : Float(3) = aten::add(%3, %y, %4)
  return (%5);
}
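Since the goal is to convert this IR into another one, a minimal sketch of walking the graph may help. It assumes the Python bindings Graph.nodes(), Node.kind(), Node.inputs(), Node.outputs() and Value.debugName() found in recent PyTorch releases; the to_toy_ir helper and its (op, inputs, outputs) tuple format are hypothetical stand-ins for "your own IR". Node attributes, such as a constant's value, are not copied in this sketch.

```python
import torch


def foo(x, y):
    return 2 * x + y


traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))


def to_toy_ir(graph):
    """Translate a TorchScript graph into hypothetical (op, inputs, outputs) tuples."""
    instructions = []
    # graph.nodes() yields the body nodes in order (graph inputs and the
    # return statement are not included).
    for node in graph.nodes():
        op = node.kind()  # e.g. "prim::Constant", "aten::mul", "aten::add"
        inputs = [v.debugName() for v in node.inputs()]
        outputs = [v.debugName() for v in node.outputs()]
        instructions.append((op, inputs, outputs))
    return instructions


for instr in to_toy_ir(traced_foo.graph):
    print(instr)
```

The same traversal exists on the C++ side through torch::jit::Graph, which is where a C++ frontend would do the translation.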

I just wonder: what is the IR graph used for? Thanks for any clarification.