Speed up loading traced network

Hello All,

Is there a way to speed up loading the traced network in C++ (Windows 10)?

In python, I first need to run a network and based on its result decide whether to do the second one or not. So my tracing code looks like this:

class MySampleNetwork(nn.Module):
    def __init__(self, params):
        super(MySampleNetwork, self).__init__()
        # trace network1
        sample_data1 = torch.rand(size)
        sample_data1 = sample_data1.to(ddata.gpu, non_blocking=True)
        net1 = load_net(net1).eval().to(ddata.gpu)
        self.net1 = torch.jit.trace(net1, sample_data1)
        # trace network2
        sample_data2 = torch.rand(size)
        sample_data2 = sample_data2.to(ddata.gpu, non_blocking=True)
        net2 = load_net(net2).eval().to(ddata.gpu)
        self.net2 = torch.jit.trace(net2, sample_data2)

    def forward(self, net1_input, net2_input):
        net1_out = self.net1(net1_input) 
        net2_out = torch.zeros(expected_out_size)

        if condition_based_on(net1_out):  # placeholder for the actual check on net1's output
            net2_out = self.net2(net2_input)
        
        # combine the results
        result = torch.zeros(net1_out_size+net2_out_size)
        result[:net1_out_size] = net1_out[0]
        result[net1_out_size:] = net2_out[0]

        return result

traced_net = torch.jit.script(MySampleNetwork(params))

The saved network is about 26 MB.
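For reference, the save step is just torch.jit.save on the scripted module; a minimal sketch (the file name here is only an example):

import torch

# save the scripted module to disk so it can be loaded from C++ with torch::jit::load
torch.jit.save(traced_net, "my_sample_network.pt")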

In C++, I load the network

torch::jit::script::Module module = torch::jit::load(net_path());

It takes more than 30 seconds to load the network. Is there anything I can do to reduce the time?

Thank you.

Try this:

torch::Device device(torch::kCUDA);
const std::string modelName = "candy_cpp.pt";
auto module = torch::jit::load(modelName, device);

How much time does this take to load?

@dambo Thank you for your response. Adding the device does not make it faster.
The problem is the if condition. Once I remove it, the module loads in about 2-4 seconds. For now, I have removed the if, so I am computing the whole thing even when I do not need it. I would still like to solve this issue if anyone can help.

Have you added print statements to see where exactly it's happening? I would time how long it takes to load the net (assuming you are loading both at the same time). My suspicion is that tracing both models in the constructor is what slows things down.
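For instance, a rough timing sketch of the constructor steps (this reuses load_net, ddata, and size from your snippet, so treat it as pseudocode for your setup rather than runnable as-is):

import time
import torch

t0 = time.perf_counter()
raw_net1 = load_net(net1).eval().to(ddata.gpu)                     # time the raw weight load
t1 = time.perf_counter()
traced_net1 = torch.jit.trace(raw_net1, torch.rand(size).to(ddata.gpu))  # time the trace itself
t2 = time.perf_counter()

print(f"load_net: {t1 - t0:.2f}s  trace: {t2 - t1:.2f}s")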

There really isn't a need to trace/script the network a total of three times. I would remove the tracing in the constructor and just script the whole graph at once. You only need to script the network when it will be used outside of a Python environment, and scripting the network at the end, as you already do, is all that is required.
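A minimal sketch of that idea, assuming net1/net2 are ordinary nn.Module instances returned by your load_net and keeping your placeholders for the condition and sizes:

import torch
import torch.nn as nn

class MySampleNetwork(nn.Module):
    def __init__(self, params):
        super().__init__()
        # keep the plain eager modules; no torch.jit.trace here
        self.net1 = load_net(net1).eval().to(ddata.gpu)
        self.net2 = load_net(net2).eval().to(ddata.gpu)

    def forward(self, net1_input, net2_input):
        net1_out = self.net1(net1_input)
        net2_out = torch.zeros(expected_out_size)
        if condition_based_on(net1_out):  # placeholder check on net1's output
            net2_out = self.net2(net2_input)
        result = torch.zeros(net1_out_size + net2_out_size)
        result[:net1_out_size] = net1_out[0]
        result[net1_out_size:] = net2_out[0]
        return result

# script once, at the end, and save for C++
scripted = torch.jit.script(MySampleNetwork(params))
torch.jit.save(scripted, "my_sample_network.pt")  # example file name

Since torch.jit.script compiles forward directly and handles data-dependent control flow, the if can stay in forward without wrapping each branch in its own traced module.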