Getting an int8 ONNX format model

My CNN engine only supports ONNX models in which both activations and weights are int8, so I must convert my torch model to an int8 ONNX model. But when converting from torch to ONNX I get this error: RuntimeError: quantized::conv(FBGEMM): Expected activation data type QUInt8 but got QInt8.
My code is:

import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import copy, os
import onnx, onnxsim
from torchvision.models import resnet18
from torchvision.models.quantization import resnet18 as QuantizedResNet18
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import QConfigMapping,  MinMaxObserver
from torch.ao.quantization.backend_config import get_tensorrt_backend_config
from torch.ao.quantization.fake_quantize import FakeQuantize


def save_model(model, model_dir, model_filename):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_filepath)


def save_onnx(model, onnx_path, input_shape=[1,3,224,224]):
    img = torch.rand(*input_shape).float()
    model.eval()
    torch.onnx.export(model.to('cpu'), img, onnx_path, input_names=['input'], output_names=['output'], opset_version=13, dynamic_axes={'input':{0 : '-1'}, 'output':{0 : '-1'}})

    # check onnx model
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    # simplify onnx model
    try:
        print('Starting to simplify ONNX...')
        onnx_model, check = onnxsim.simplify(onnx_model)
        assert check, 'assert check failed'
    except Exception as e:
        print('Simplifier failure:', e)
    onnx.save(onnx_model, onnx_path)


if __name__ == '__main__':
    cuda_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu:0")
    model = resnet18(pretrained=True)
    save_model(model, 'saved_models/', 'resnet18_fp32.pt')

    fused_model = copy.deepcopy(model)
    fused_model.eval()
    fused_model = torch.quantization.fuse_modules(fused_model, [["conv1", "bn1", "relu"]], inplace=True)

    for module_name, module in fused_model.named_children():
        if "layer" in module_name:
            for basic_block_name, basic_block in module.named_children():
                torch.quantization.fuse_modules(
                    basic_block, [["conv1", "bn1", "relu"], ["conv2", "bn2"]],
                    inplace=True)
                for sub_block_name, sub_block in basic_block.named_children():
                    if sub_block_name == "downsample":
                        torch.quantization.fuse_modules(sub_block, [["0", "1"]], inplace=True)
    quantized_model = QuantizedResNet18()
    # load the saved fp32 weights into the quantizable model (load_state_dict returns
    # a key report, not the model, so don't reassign quantized_model here)
    quantized_model.load_state_dict(torch.load('saved_models/resnet18_fp32.pt', map_location=cuda_device))
    quantized_model.fuse_model()

    backend_config = get_tensorrt_backend_config()

    qconfig_mapping = QConfigMapping()
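    # force qint8 for both activations and weights (my engine needs all-int8 in the ONNX graph)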
    qconfig = torch.ao.quantization.QConfig(
        activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
        weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
    qconfig_mapping.set_global(qconfig)
    dummy_input = torch.randn(1, 3, 224, 224)
    model_prepared = quantize_fx.prepare_qat_fx(quantized_model, qconfig_mapping=qconfig_mapping, example_inputs=dummy_input, backend_config=backend_config)

    # convert on CPU
    model_prepared.to(cpu_device)
    model_quantized = quantize_fx.convert_fx(model_prepared, qconfig_mapping=qconfig_mapping, backend_config=backend_config)
    save_onnx(model_quantized, 'saved_models/tmp.onnx')

    print('finished!')

Can anyone help me? Thanks very much!

That is to say, the error occurs when saving to ONNX.
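For reference, I believe the FBGEMM backend only implements quantized::conv for quint8 activations with qint8 weights, which seems to be why the error asks for QUInt8. A quick way to check the default FBGEMM qconfig (assuming a recent torch version):

import torch

# default FBGEMM qconfig: quint8 activations, qint8 (per-channel) weights
fbgemm_qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
print(fbgemm_qconfig.activation().dtype)  # torch.quint8
print(fbgemm_qconfig.weight().dtype)      # torch.qint8

So if I switch my activation observer to torch.quint8 the conversion presumably stops complaining, but then the activations are no longer int8, which my engine cannot accept.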

Or is there any other method or setting to obtain an int8 ONNX model?
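One alternative I am considering (not sure whether it is the right approach): export the fp32 model to ONNX first, then run ONNX Runtime's post-training static quantization, which can emit a QDQ graph with int8 activations and weights. A rough sketch, assuming the fp32 model has already been exported to saved_models/resnet18_fp32.onnx and using random data as a stand-in for a real calibration set:

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static


class RandomCalibReader(CalibrationDataReader):
    # stand-in calibration reader; a real one should yield preprocessed dataset images
    def __init__(self, num_batches=16):
        self.batches = iter(
            [{'input': np.random.rand(1, 3, 224, 224).astype(np.float32)} for _ in range(num_batches)])

    def get_next(self):
        return next(self.batches, None)


quantize_static(
    'saved_models/resnet18_fp32.onnx',       # fp32 export (assumed to exist)
    'saved_models/resnet18_int8_qdq.onnx',   # output with QuantizeLinear/DequantizeLinear pairs
    RandomCalibReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QInt8,         # int8 activations
    weight_type=QuantType.QInt8,             # int8 weights
    per_channel=True)

Would something like this give me the all-int8 ONNX model my engine needs, or can the FX quantization path above be exported directly?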
