ONNX mixed-precision model runs slowly in ONNX Runtime

Hi, I am trying to export a mixed-precision (AMP) model to ONNX. Unfortunately, the exported model is much slower when I run it in ONNX Runtime:

import onnxruntime as ort
from functools import partial
import onnx
import time
import timeit
import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch.utils.benchmark as benchmark


class Net(nn.Module):
    """ResNet-18 wrapper whose forward pass can optionally run under CUDA autocast.

    Args:
        amp: If True, ``forward`` evaluates the backbone inside
            ``torch.cuda.amp.autocast()``; otherwise it runs without any
            autocast context of its own.
    """

    def __init__(self, amp=False):
        super().__init__()
        self.amp = amp
        # Backbone built with torchvision's default constructor arguments.
        self._model = resnet18()

    def forward(self, x):
        """Return the backbone's output for ``x``, using autocast when ``self.amp`` is set."""
        if not self.amp:
            return self._model(x)
        # NOTE(review): the autocast region is also active while torch.onnx.export
        # traces this module, so the precision casts it performs are likely
        # recorded in the exported graph — confirm by inspecting the .onnx file.
        with torch.cuda.amp.autocast():
            return self._model(x)


def bench_torch(net, x, N, warmup=1):
    """Time ``net(x)`` under ``torch.inference_mode()`` and print the mean latency in ms.

    When CUDA is available, the device is synchronized after the warm-up and
    after every timed call. Each measured iteration therefore covers the
    complete GPU work, not just the asynchronous kernel launches.

    Args:
        net: Callable model (e.g. an ``nn.Module``) applied to ``x``.
        x: Input passed unchanged to ``net`` on every call.
        N: Number of timed iterations; must be positive.
        warmup: Number of untimed calls made before timing starts (default 1,
            matching the previous behavior).

    Returns:
        Mean wall-clock time per iteration in milliseconds (also printed).

    Raises:
        ValueError: If ``N`` is not positive.
    """
    if N <= 0:
        raise ValueError(f"N must be positive, got {N}")
    # torch.cuda.synchronize() would raise on machines without CUDA, so only
    # synchronize when a GPU is present; on CPU the calls are already blocking.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    with torch.inference_mode():
        for _ in range(warmup):
            net(x)
        sync()
        # perf_counter is the monotonic, high-resolution clock intended for
        # measuring intervals; time.time() can be coarse (notably on Windows).
        t0 = time.perf_counter()
        for _ in range(N):
            net(x)
            sync()  # wait for the asynchronous GPU work of this call to finish
        elapsed = time.perf_counter() - t0
    ms = (elapsed / N) * 1000
    print(f"{ms:6.2f}")
    return ms


def bench_onnx(name, x, N, warmup=1, input_name='input',
               providers=('CUDAExecutionProvider',)):
    """Time ONNX Runtime inference of the model file ``name`` and print the mean latency in ms.

    Args:
        name: Path to the ``.onnx`` model file.
        x: Input array (the script passes a NumPy array) fed to ``input_name``.
        N: Number of timed ``session.run`` calls; must be positive.
        warmup: Number of untimed runs before timing starts (default 1,
            matching the previous behavior).
        input_name: Name of the graph input; defaults to ``'input'``, the name
            used by ``export`` in this file.
        providers: Execution providers requested for the session (default
            CUDA only, as before).

    Returns:
        Mean wall-clock time per ``session.run`` call in milliseconds (also printed).

    Raises:
        ValueError: If ``N`` is not positive.
    """
    if N <= 0:
        raise ValueError(f"N must be positive, got {N}")
    session = ort.InferenceSession(name, providers=list(providers))
    # NOTE(review): if the CUDA provider cannot be used, ONNX Runtime may fall
    # back to the CPU provider; check session.get_providers() when timings look off.
    feed = {input_name: x}
    for _ in range(warmup):
        session.run(None, feed)
    # NOTE(review): with a host (NumPy) input, each timed run presumably also
    # includes the host-to-device input copy and device-to-host output copy,
    # unlike the PyTorch benchmark in this script, whose input already lives on
    # the GPU. ONNX Runtime IO binding can exclude those copies.
    t0 = time.perf_counter()
    for _ in range(N):
        session.run(None, feed)
    ms = ((time.perf_counter() - t0) / N) * 1000
    print(f"{ms:6.2f}")
    return ms

def export(net, name, x, opset_version=15):
    """Export ``net`` to an ONNX file at ``name``, validate it, and print a confirmation.

    The model is traced with the example input ``x``. The graph's input is
    named ``'input'`` with a dynamic batch dimension (axis 0), and its output
    is named ``'logits'``.

    Args:
        net: Model to export.
        name: Destination path for the ``.onnx`` file.
        x: Example input used for tracing.
        opset_version: ONNX opset to target (default 15, as before).

    Raises:
        Whatever ``onnx.checker.check_model`` raises when the exported model
        fails the full check (e.g. a validation error).
    """
    torch.onnx.export(
        net,
        (x,),
        name,
        input_names=['input'],
        output_names=['logits'],
        dynamic_axes={'input': {0: 'batch'}},
        opset_version=opset_version,
    )
    # check_model returns None and reports problems by raising, so call it
    # directly. Wrapping it in `assert` added nothing, and under `python -O`
    # the whole assert statement (and thus the check) was stripped out.
    onnx.checker.check_model(name, full_check=True)

    print("exported model", name)

# Benchmark script: build an FP32 and an autocast variant of the same ResNet-18
# wrapper on the GPU, export both to ONNX, then time each in PyTorch and in
# ONNX Runtime.
# NOTE(review): the two Net instances are constructed independently, so they
# presumably hold different random weights — fine for timing, not for
# comparing outputs between the variants.
net = Net(False).eval().cuda()
net_amp = Net(True).eval().cuda()
# Batch of 500 random 3x150x150 inputs, created on the CPU and copied to the GPU.
x = torch.randn(500, 3, 150, 150).cuda()
# Host (NumPy) copy of the same data; this is what the ONNX Runtime benchmarks receive.
xn = x.cpu().numpy()

# Both models are traced with the GPU tensor `x`.
# NOTE(review): net_amp's forward runs under autocast while being traced, so
# net_amp.onnx likely contains the casts autocast performed. Inspecting the
# graph (e.g. counting Cast nodes) may explain its much slower ONNX Runtime timing.
export(net, 'net.onnx', x)
export(net_amp, 'net_amp.onnx', x)



N = 10  # number of timed iterations per benchmark
# Trailing numbers: mean ms per iteration as reported by the poster.
bench_torch(net, x, N) # 117.28
bench_torch(net_amp, x, N) # 56.61
# NOTE(review): these runs are fed host memory (xn), while the PyTorch runs use
# the GPU-resident `x`, so the two sets of timings are not strictly comparable.
bench_onnx('net.onnx', xn, N) # 134.93
bench_onnx('net_amp.onnx', xn, N) # 2762.18

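Is there anything wrong with my method? The FP32 export runs at roughly PyTorch speed in ONNX Runtime, but the AMP export is about 20x slower than the FP32 export.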

System:

  • Windows 10
  • NVIDIA RTX 2080
  • PyTorch 1.11.0
  • onnxruntime-gpu 1.10.0

Thanks

This is quite important for me. Does anyone have any insight into this?