My CNN inference engine only supports ONNX models in which both the activations and the weights are int8, so I have to convert my PyTorch model to an int8 ONNX model. But during the PyTorch-to-ONNX conversion I get this error:

RuntimeError: quantized::conv(FBGEMM): Expected activation data type QUInt8 but got QInt8
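For comparison, I checked the default FBGEMM qconfig, which (as far as I understand) reflects what the FBGEMM kernels accept: quint8 activations and qint8 weights. My engine needs qint8 for both, which is where the conflict seems to come from:

import torch
# Default FBGEMM qconfig: observers record quint8 activations and qint8 weights.
default_qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
print(default_qconfig.activation().dtype)  # torch.quint8
print(default_qconfig.weight().dtype)      # torch.qint8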
My code is:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import copy, os
import onnx, onnxsim
from torchvision.models import resnet18
from torchvision.models.quantization import resnet18 as QuantizedResNet18
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import QConfigMapping, MinMaxObserver
from torch.ao.quantization.backend_config import get_tensorrt_backend_config
from torch.ao.quantization.fake_quantize import FakeQuantize
def save_model(model, model_dir, model_filename):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_filepath)
def save_onnx(model, onnx_path, input_shape=[1, 3, 224, 224]):
    img = torch.rand(*input_shape).float()
    model.eval()
    torch.onnx.export(model.to('cpu'), img, onnx_path,
                      input_names=['input'], output_names=['output'],
                      opset_version=13,
                      dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
    # check the exported onnx model
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    # simplify the onnx model
    try:
        print('Starting to simplify ONNX...')
        onnx_model, check = onnxsim.simplify(onnx_model)
        assert check, 'assert check failed'
    except Exception as e:
        print('Simplifier failure:', e)
    onnx.save(onnx_model, onnx_path)
if __name__ == '__main__':
    cuda_device = torch.device("cuda:0")
    cpu_device = torch.device("cpu")
    model = resnet18(pretrained=True)
    save_model(model, 'saved_models/', 'resnet18_fp32.pt')
    # eager-mode fusion on a copy of the float model
    fused_model = copy.deepcopy(model)
    fused_model.eval()
    fused_model = torch.quantization.fuse_modules(fused_model, [["conv1", "bn1", "relu"]], inplace=True)
    for module_name, module in fused_model.named_children():
        if "layer" in module_name:
            for basic_block_name, basic_block in module.named_children():
                torch.quantization.fuse_modules(
                    basic_block, [["conv1", "bn1", "relu"], ["conv2", "bn2"]],
                    inplace=True)
                for sub_block_name, sub_block in basic_block.named_children():
                    if sub_block_name == "downsample":
                        torch.quantization.fuse_modules(sub_block, [["0", "1"]], inplace=True)
    # reload the float weights into the quantizable model
    # (load_state_dict returns IncompatibleKeys, not the model, so don't reassign)
    quantized_model = QuantizedResNet18()
    quantized_model.load_state_dict(torch.load('saved_models/resnet18_fp32.pt', map_location=cpu_device))
    quantized_model.fuse_model()
    backend_config = get_tensorrt_backend_config()
    qconfig_mapping = QConfigMapping()
    # my engine needs int8 for both activations and weights, hence qint8 everywhere
    qconfig = torch.ao.quantization.QConfig(
        activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
        weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
    qconfig_mapping.set_global(qconfig)
    dummy_input = torch.randn(1, 3, 224, 224)
    model_prepared = quantize_fx.prepare_qat_fx(quantized_model, qconfig_mapping=qconfig_mapping,
                                                example_inputs=dummy_input, backend_config=backend_config)
    model_prepared.to(cpu_device)
    model_quantized = quantize_fx.convert_fx(model_prepared, qconfig_mapping=qconfig_mapping,
                                             backend_config=backend_config)
    save_onnx(model_quantized, 'saved_models/tmp.onnx')
    print('finished!')
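If it helps to narrow things down, I believe this minimal snippet (a standalone sketch, not taken from my pipeline) triggers the same FBGEMM constraint, since the quantized conv kernel only accepts quint8 activations:

import torch
# FBGEMM-backed quantized conv module
qconv = torch.nn.quantized.Conv2d(3, 3, kernel_size=1)
# feed it a qint8 tensor instead of the quint8 it expects
x = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.qint8)
out = qconv(x)  # raises: Expected activation data type QUInt8 but got QInt8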
Can anyone help me? Thanks very much!