Hello All,
Is there a way to speed up loading the traced network in C++ (Windows 10)?
In Python, I first run one network and, based on its result, decide whether to run a second one. My tracing code looks like this:
import torch
import torch.nn as nn

class MySampleNetwork(nn.Module):
    def __init__(self, params):
        super(MySampleNetwork, self).__init__()
        # trace network1
        sample_data1 = torch.rand(size)
        sample_data1 = sample_data1.to(ddata.gpu, non_blocking=True)
        net1 = load_net(net1).eval().to(ddata.gpu)
        self.net1 = torch.jit.trace(net1, sample_data1)
        # trace network2
        sample_data2 = torch.rand(size)
        sample_data2 = sample_data2.to(ddata.gpu, non_blocking=True)
        net2 = load_net(net2).eval().to(ddata.gpu)
        self.net2 = torch.jit.trace(net2, sample_data2)

    def forward(self, net1_input, net2_input):
        net1_out = self.net1(net1_input)
        net2_out = torch.zeros(expected_out_size)
        if <condition based on net1_out>:  # placeholder for the actual check
            net2_out = self.net2(net2_input)
        # combine the results
        result = torch.zeros(net1_out_size + net2_out_size)
        result[:net1_out_size] = net1_out[0]
        result[net1_out_size:] = net2_out[0]
        return result

# script the wrapper so the data-dependent branch in forward() is preserved
traced_net = torch.jit.script(MySampleNetwork(params))
The saved network file is about 26 MB.
In C++, I load the network with:
torch::jit::script::Module module = torch::jit::load(net_path());
It takes more than 30 seconds to load the network. Is there anything I can do to reduce the time?
Thank you.