Speed up loading traced network

Hello All,

Is there a way to speed up loading the traced network in C++ (Windows 10)?

In Python, I first run one network and, based on its result, decide whether to run a second one. My tracing code looks like this:

import torch
import torch.nn as nn

class MySampleNetwork(nn.Module):
    def __init__(self, params):
        super(MySampleNetwork, self).__init__()
        # trace network1
        sample_data1 = torch.rand(size)
        sample_data1 = sample_data1.to(ddata.gpu, non_blocking=True)
        net1 = load_net(net1).eval().to(ddata.gpu)
        self.net1 = torch.jit.trace(net1, sample_data1)
        # trace network2
        sample_data2 = torch.rand(size)
        sample_data2 = sample_data2.to(ddata.gpu, non_blocking=True)
        net2 = load_net(net2).eval().to(ddata.gpu)
        self.net2 = torch.jit.trace(net2, sample_data2)

    def forward(self, net1_input, net2_input):
        net1_out = self.net1(net1_input) 
        net2_out = torch.zeros(expected_out_size)

        if condition based on net1_out:
            net2_out = self.net2(net2_input)
        # combine the results
        result = torch.zeros(net1_out_size+net2_out_size)
        result[:net1_out_size] = net1_out[0]
        result[net1_out_size:] = net2_out[0]

        return result

traced_net = torch.jit.script(MySampleNetwork(params))

The saved network file is about 26 MB.

In C++, I load the network like this:

torch::jit::script::Module module = torch::jit::load(net_path());

It takes more than 30 seconds to load the network. Is there anything I can do to reduce the time?

Thank you.

Try this:

torch::Device device(torch::kCUDA);
const std::string modelName = "candy_cpp.pt";
auto module = torch::jit::load(modelName, device);

How much time does this take to load?
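
If it helps to get a concrete number, here is a minimal sketch of a timing helper that uses only the standard library. The name time_ms is hypothetical (it is not part of libtorch); it just runs whatever callable you pass it and returns the elapsed wall-clock time in milliseconds.

```cpp
#include <chrono>
#include <utility>

// Hypothetical helper (not a libtorch API): runs fn once and returns
// the elapsed wall-clock time in milliseconds.
template <typename F>
double time_ms(F&& fn) {
    const auto start = std::chrono::steady_clock::now();
    std::forward<F>(fn)();
    const auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(stop - start).count();
}

// Possible usage with the snippet above (assumes <torch/script.h> is included
// and that modelName and device are defined as shown):
//
//   torch::jit::script::Module module;
//   double ms = time_ms([&] { module = torch::jit::load(modelName, device); });
//   std::cout << "load took " << ms << " ms\n";
```

Timing the load with and without the device argument should show whether the extra argument makes a difference on your machine.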