What is torch.jit.trace(model, input) doing?

Hi,
I am running a model trained in Python in C++ and am trying to speed it up as much as possible. I noticed that tracing the model with torch.jit.trace(model, input) actually slows down prediction, and I have no idea why. Here is my test code:

import time
import torch

model = Model(128, 128, 19**2, 0.1)
input = torch.ones(1, 1, 128, 128)

torch.cuda.synchronize()
torch.set_num_threads(1)
model.eval()

start_time1 = time.time()
output1 = model(input)
print('Time for total prediction = {0}'.format(time.time()-start_time1))

# Trace the model: records the ops executed on this input into a ScriptModule
traced_model = torch.jit.trace(model, input)

start_time2 = time.time()
output2 = traced_model(input)
print('Time for total prediction = {0}'.format(time.time()-start_time2))

And the output is:

Time for convolutional layers = 0.005982398986816406
Time for hidden layers = 0.001995086669921875
Time for total prediction = 0.008975505828857422
Time for convolutional layers = 0.007979154586791992
Time for hidden layers = 0.00399017333984375
Time for convolutional layers = 0.009974002838134766
Time for hidden layers = 0.002992391586303711
Time for convolutional layers = 0.005983114242553711
Time for hidden layers = 0.0029931068420410156
Time for total prediction = 0.015474081039428711
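
To rule out one-off effects, a fairer comparison would probably warm both models up and average over many runs; a minimal sketch (the warm-up and run counts are arbitrary choices):

def benchmark(fn, x, warmup=10, runs=100):
    # Exclude the first few calls from the timing: the initial calls to a
    # traced model can still include JIT optimization work.
    with torch.no_grad():
        for _ in range(warmup):
            fn(x)
        start = time.time()
        for _ in range(runs):
            fn(x)
    return (time.time() - start) / runs

print('Eager average  = {0}'.format(benchmark(model, input)))
print('Traced average = {0}'.format(benchmark(traced_model, input)))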
  1. Why does torch.jit.trace(model, input) run the forward() method three times? (see the sketch after these questions)
  2. Why is the traced forward() method so much slower?
  3. When the traced model is loaded in C++, prediction even takes 0.03 s, so it is again much slower than the values here. Is there a way to make it as fast as in Python?
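
Regarding question 1: a sketch to check whether the extra runs come from trace verification, assuming the default check_trace=True is what re-runs the module:

# check_trace=True (the default) re-runs the module to verify the trace,
# which may account for some of the extra forward() executions.
traced_model_nocheck = torch.jit.trace(model, input, check_trace=False)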