Return a handle from a C++ compile function for use by a Dynamo backend

I’ve created a very simple Dynamo backend that runs and replaces each call_module node in the graph with mock_compiled_handle.

import torch

def mock_compiled_handle(inputs):
    print("Ran mock_compiled_handle")
    return inputs

def toy_backend(gm, sample_inputs):
    exec_graph = []  # list of compiled functions to run
    xx = sample_inputs[0]
    for n in gm.graph.nodes:
        if n.op == "call_module":
            exec_graph.append(mock_compiled_handle)
            a = gm.get_submodule(n.target)
            xx = a(xx)

    def exec_topt(*args):
        outs = None
        for fun in exec_graph:
            if outs is None:
                outs = fun(args[-1])
            else:
                outs = fun(outs)
        return [outs]

    return exec_topt


# Model is an nn.Module subclass defined elsewhere (not shown here).
torch._dynamo.reset()
fn = torch.compile(Model(), backend=toy_backend, dynamic=True)

t = torch.zeros(5)
out = fn(t)

I now want to do the same thing, but with the mocked handle created in C++. Following this example, I have already managed to call a C++ function from Python. The compile function in C++ currently looks like this:

#include <torch/extension.h>
#include <iostream>
#include <vector>

void customCompile(torch::jit::Graph &graph, std::vector<torch::Tensor> &tensors)
{
    std::cout << "Function customCompile executed." << std::endl;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("custom_compile", &customCompile, "Test compile function to be used by Dynamo backend.");
}

and can be invoked as follows:

xx = sample_inputs[0]
a = gm.get_submodule(n.target)
ts_trace = torch.jit.trace(a, xx)
ts_trace = torch.jit.freeze(ts_trace.eval())
test = dynamo_backend.custom_compile(ts_trace.graph, [xx])

Instead of returning void, it should return a callable handle that the backend can invoke later, in the same way as mock_compiled_handle. How can this be achieved?

Any resources or examples regarding simple dynamo backend implementations would also be very useful!

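You can use the return type std::function<torch::Tensor &(torch::Tensor &)>. pybind11 converts a returned std::function into a Python callable, provided pybind11/functional.h is included (add the include if your extension headers do not already pull it in). Something like this would work:

#include <functional>
#include <pybind11/functional.h>  // std::function <-> Python callable conversion

std::function<torch::Tensor &(torch::Tensor &)> customCompile(torch::jit::Graph &graph, std::vector<torch::Tensor> &tensors)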
{
    return [](torch::Tensor &a) -> torch::Tensor &
    {
        std::cout << "Test handle called." << std::endl;
        return a;
    };
}
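
On the Python side, the returned handle can then be appended to exec_graph just as mock_compiled_handle was. The following is only a rough, torch-free sketch of that wiring: fake_custom_compile is a hypothetical stand-in for dynamo_backend.custom_compile (the real binding would hand back the C++ std::function as a Python callable), build_executor is an illustrative helper name, and plain lists stand in for tensors.

```python
from typing import Callable, Sequence


def fake_custom_compile(graph, tensors) -> Callable:
    """Hypothetical stand-in for dynamo_backend.custom_compile.

    The real extension would return the C++ lambda; here a Python
    closure plays that role so the sketch runs without torch.
    """
    def handle(x):
        print("Test handle called.")
        return x
    return handle


def build_executor(handles: Sequence[Callable]) -> Callable:
    """Chain the handles in order, similar to exec_topt in the toy backend."""
    def exec_topt(*args):
        outs = args[-1]          # start from the last positional input
        for fun in handles:      # feed each handle the previous output
            outs = fun(outs)
        return [outs]
    return exec_topt


# One handle per call_module node; graph and inputs are placeholders here.
handles = [fake_custom_compile(None, [None]) for _ in range(2)]
run = build_executor(handles)
result = run([1.0, 2.0])
```

In the actual backend, the loop over gm.graph.nodes would call dynamo_backend.custom_compile(ts_trace.graph, [xx]) for each call_module node and append the result to exec_graph in place of mock_compiled_handle.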