Inference in ONNX mixed precision model

Hello,

I trained an frcnn (Faster R-CNN) model with automatic mixed precision and exported it to ONNX. What should inference look like in code so that it benefits from the mixed-precision speedup? PyTorch uses with autocast():, and I can’t work out what the equivalent would be in an inference engine such as onnxruntime.

My specs:
torch==1.6.0+cu101
torchvision==0.7.0+cu101
onnx==1.7.0
onnxruntime-gpu==1.4.0

Model exports just fine:

torch.onnx.export(model, 
                  x, 
                  "model_16.onnx", 
                  verbose=True, do_constant_folding=True, opset_version=12,
                  input_names=input_names, output_names=output_names)

But I wonder how to get the mixed-precision speedup here:


import onnxruntime as ort
ort_session = ort.InferenceSession('model_16.onnx')
outputs = ort_session.run(None, {'input': x.numpy()})

I’m not exactly sure how ONNX exports the model, but if tracing is used, the mixed-precision operations might already have been recorded. Do you see any FP16 operations if you profile the ONNX model?
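Besides profiling, one way to check is to look at the exported graph itself and count which dtypes the weights and Cast nodes use. The following is only a minimal sketch: summarize_dtypes and report are hypothetical helper names (not part of onnx), it assumes the onnx package is installed, and it only walks the top-level graph, not any subgraphs.

```python
from collections import Counter

# Subset of the onnx.TensorProto data type enum (FLOAT=1, FLOAT16=10, ...).
DTYPE_NAMES = {1: "FLOAT", 6: "INT32", 7: "INT64", 10: "FLOAT16", 11: "DOUBLE"}


def summarize_dtypes(model):
    """Count initializer dtypes and Cast targets in a loaded ONNX ModelProto."""
    weights = Counter(
        DTYPE_NAMES.get(init.data_type, str(init.data_type))
        for init in model.graph.initializer
    )
    casts = Counter()
    for node in model.graph.node:
        if node.op_type != "Cast":
            continue
        for attr in node.attribute:
            if attr.name == "to":  # target dtype of the Cast
                casts[DTYPE_NAMES.get(attr.i, str(attr.i))] += 1
    return weights, casts


def report(path):
    """Hypothetical convenience wrapper: load a model file and print the counts."""
    import onnx  # imported lazily so summarize_dtypes has no hard dependency

    weights, casts = summarize_dtypes(onnx.load(path))
    print("initializer dtypes:", dict(weights))
    print("Cast targets:", dict(casts))
```

Calling report("model_16.onnx") prints both counts. If every initializer is FLOAT and no Cast targets FLOAT16, the exported graph is most likely plain FP32, and there would be no mixed-precision speedup for the runtime to pick up.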

Hello,

thanks for the suggestion. I ran the onnxruntime profiler on the FP16 model exported from torch, but I’m not sure how to spot FP16 operations in its output. I’m attaching the log:
http://www.mediafire.com/file/rel0ze3y963nohv/onnxruntime_profile__2020-08-28_10-35-51.json/file
Here’s the code I used to run the profiler:

import onnxruntime as ort
options = ort.SessionOptions()
options.enable_profiling = True
ort_session = ort.InferenceSession('model_16.onnx', options)
outputs = ort_session.run(None, {'input': images[0].cpu().numpy()})
prof_file = ort_session.end_profiling()
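The profile is a JSON list of trace events, so it can be summarized with a few lines of Python. This is a rough sketch, not an official tool: summarize_profile and print_profile are hypothetical names, and the field names used (cat == "Node", a name ending in "_kernel_time", args.op_name, args.provider) are assumptions based on typical onnxruntime profiles, so they should be checked against the attached file.

```python
import json
from collections import defaultdict


def summarize_profile(events):
    """Sum kernel time (microseconds) per (op type, execution provider)."""
    totals = defaultdict(int)
    for event in events:
        if event.get("cat") != "Node":
            continue  # skip session-level events
        if not event.get("name", "").endswith("_kernel_time"):
            continue  # skip fence/bookkeeping events for the same node
        args = event.get("args", {})
        key = (args.get("op_name", "?"), args.get("provider", "?"))
        totals[key] += event.get("dur", 0)
    # Most expensive op types first.
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)


def print_profile(prof_file, top=15):
    """Print the `top` most expensive (op, provider) pairs from a profile file."""
    with open(prof_file) as f:
        events = json.load(f)
    for (op, provider), dur_us in summarize_profile(events)[:top]:
        print(f"{op:<20} {provider:<28} {dur_us / 1000:.2f} ms")
```

Running print_profile(prof_file) shows where the time goes and which execution provider ran each op type. If most of the time is spent under CPUExecutionProvider, the session is not really running on the GPU, which would matter more than FP16 versus FP32.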

Anyway, with a simple timing measurement of inference I don’t see much difference. To be honest, I expected the ONNX model to run faster.

1. Pure torch FP16 model:

from imutils.video import FPS
from torch.cuda.amp import autocast  # provides the autocast() context used below
fps = FPS().start()
for i in range(100):
    images = list(image.to('cuda:0') for image in x)
    with autocast():
        pred = model(images)
    
    fps.update()

fps.stop()
print('Time taken: {:.2f}'.format(fps.elapsed()))
print('~ FPS : {:.2f}'.format(fps.fps()))

Time taken: 2.19
~ FPS : 45.57

2. Torch->ONNX FP16 model:

import onnxruntime as ort
ort_session = ort.InferenceSession('model_16.onnx')
fps = FPS().start()

for i in range(100):
    outputs = ort_session.run(None, {'input': images[0].cpu().numpy()})
    fps.update()

fps.stop()
print('Time taken: {:.2f}'.format(fps.elapsed()))
print('~ FPS : {:.2f}'.format(fps.fps()))

Time taken: 2.15
~ FPS : 46.61
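A caveat on comparing the two numbers: CUDA work is queued asynchronously, so the torch loop may stop the clock before the GPU has finished, while the ONNX loop also times the images[0].cpu().numpy() conversion on every iteration. Below is a minimal sketch of a timing helper with warmup and an optional synchronization hook; measure_fps is a hypothetical name, not part of imutils, torch or onnxruntime.

```python
import time


def measure_fps(step, iterations=100, warmup=10, sync=None):
    """Time `step()` and return iterations per second.

    `sync` is an optional callable (e.g. torch.cuda.synchronize) that blocks
    until queued GPU work has finished, so the clock covers the whole run.
    """
    for _ in range(warmup):  # untimed runs to absorb one-off setup costs
        step()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iterations):
        step()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return iterations / elapsed if elapsed > 0 else float("inf")


# Assumed usage with the objects from the snippets above:
#   measure_fps(lambda: model(images), sync=torch.cuda.synchronize)
#   feed = {'input': images[0].cpu().numpy()}  # convert once, outside the loop
#   measure_fps(lambda: ort_session.run(None, feed))
```

With both loops measured the same way, the remaining difference (if any) should reflect the runtimes themselves rather than transfer or conversion overhead.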