ONNX export of quantized model

@neginraoof @addisonklinke
In my case, torch.quantization.convert creates an additional bias entry with the value None for some layers, even though those layers have no bias in the full-precision model.

Then, during torch.onnx.export, torch.jit._unique_state_dict complains about calling detach() on a NoneType, since it expects a Tensor there.

torch.__version__: 1.9.0+cu111

Below is the code to quickly reproduce that:

import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
import numpy as np
import copy
import os

from torch.quantization import QuantStub, DeQuantStub

def print_model(model, tag):

    print(tag.upper(), 'MODEL')
    print(model)
    for name, value in model.state_dict().items():
        try:
            print(name, value.shape)
        except AttributeError:
            # the quantized model ends up with some None-valued entries here
            print(name, value)

def check_onnx_export(model, x, tag):

    model.eval()
    print('\nEXPORTING', tag.upper(), 'TO ONNX')
    os.makedirs('tmp', exist_ok=True)
    path = 'tmp/test-{}.onnx'.format(tag)
    torch_output = model(x).detach()
    torch.onnx.export(model, x, path, verbose=True)
    # torch.onnx.export(model, x, path, verbose=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
    print('CHECKING')
    onnx_model = onnx.load(path)
    onnx.checker.check_model(onnx_model)
    ort_session = ort.InferenceSession(path)
    input_name = ort_session.get_inputs()[0].name
    ort_outputs = ort_session.run(None, {input_name: x.numpy()})
    print(torch_output.shape, ort_outputs[0].shape)
    np.testing.assert_allclose(torch_output.numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)
    print('FINISH')

def fuse(model):

    model_fused = copy.deepcopy(model)
    for m in model_fused.modules():
        if type(m) is Conv:
            # fold the BatchNorm into the preceding Conv before QAT preparation
            torch.quantization.fuse_modules(m, ['conv', 'bn'], inplace=True)
    return model_fused


class Conv(nn.Module):

    def __init__(self):    
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(3, 3, 1, bias=False)
        self.bn = nn.BatchNorm2d(3)
        self.quant = QuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.cv1 = Conv()
        self.cv2 = nn.Conv2d(3, 3, 1, bias=False)
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.cv1(x)
        x = self.cv2(x)
        x = self.dequant(x)
        return x 


x = torch.rand(3, 3, 32, 32)
model = Model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
model_fused = fuse(model)
model_qat = torch.quantization.prepare_qat(model_fused)
model_qat.eval()
model_int8 = torch.quantization.convert(model_qat)

print_model(model, 'full')
print_model(model_int8, 'int8')
check_onnx_export(model_int8, x, 'int8')
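
For reference, a quick way to confirm where the stray bias entries show up is to scan the converted model's state_dict for None values (a minimal sketch, reusing model_int8 from the script above):

# list the state_dict entries of the converted model that are None;
# these are the values torch.jit._unique_state_dict later calls detach() on
for name, value in model_int8.state_dict().items():
    if value is None:
        print('None entry:', name)

The entries flagged here are exactly the bias keys that have no counterpart in the full-precision model's state_dict.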