@neginraoof @addisonklinke
In my case `torch.quantization.convert` creates an additional `bias` with a `None` value for some layers, even though there is no `bias` there in the full-precision model. Then, during `torch.onnx.export`, `torch.jit._unique_state_dict` complains about calling `detach()` on a `NoneType`, since it expects a `Tensor` there.
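
The failing pattern, as far as I can tell, boils down to something like this (a minimal sketch of what the export path does, not the actual `_unique_state_dict` source; the key name is made up):

```python
# Minimal sketch: a state_dict containing a None-valued bias entry, iterated the
# way the export path iterates it. The key name here is only illustrative.
state_dict = {'cv1.conv.bias': None}
for name, value in state_dict.items():
    value.detach()  # AttributeError: 'NoneType' object has no attribute 'detach'
```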
`torch.__version__`: `1.9.0+cu111`
Below is the code to quickly reproduce that:
```python
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
import numpy as np
import copy
from torch.quantization import QuantStub, DeQuantStub


def print_model(model, tag):
    print(tag.upper(), 'MODEL')
    print(model)
    for item in model.state_dict().items():
        try:
            print(item[0], item[1].shape)
        except:
            # some entries (e.g. the None biases) have no .shape
            print(item[0], item[1])


def check_onnx_export(model, x, tag):
    model.eval()
    print('\nEXPORTING', tag.upper(), 'TO ONNX')
    path = 'tmp/test-{}.onnx'.format(tag)
    torch_output = model(x).detach()
    torch.onnx.export(model, x, path, verbose=True)
    # torch.onnx.export(model, x, path, verbose=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
    print('CHECKING')
    model = onnx.load(path)
    onnx.checker.check_model(model)
    ort_session = ort.InferenceSession(path)
    ort_outputs = ort_session.run(None, {'input.1': np.array(x).astype(np.float32)})
    print(torch_output.shape, ort_outputs[0].shape)
    np.testing.assert_allclose(np.array(torch_output), ort_outputs[0], rtol=1e-03, atol=1e-05)
    print('FINISH')


def fuse(model):
    model_fused = copy.deepcopy(model)
    for m in model_fused.modules():
        if type(m) is Conv:
            torch.quantization.fuse_modules(m, ['conv', 'bn'], inplace=True)
    return model_fused


class Conv(nn.Module):
    def __init__(self):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(3, 3, 1, bias=False)  # no bias in the full model
        self.bn = nn.BatchNorm2d(3)
        self.quant = QuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        return x


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.cv1 = Conv()
        self.cv2 = nn.Conv2d(3, 3, 1, bias=False)
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.cv1(x)
        x = self.cv2(x)
        x = self.dequant(x)
        return x


x = torch.rand(3, 3, 32, 32)

model = Model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
model_fused = fuse(model)
model_qat = torch.quantization.prepare_qat(model_fused)
model_qat.eval()
model_int8 = torch.quantization.convert(model_qat)

print_model(model, 'full')
print_model(model_int8, 'int8')
check_onnx_export(model_int8, x, 'int8')
```
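
For anyone checking whether they hit the same thing: a quick diagnostic (just a sketch, reusing `model_int8` from the repro above) is to list which `state_dict` entries come out as `None` after `convert`, before calling the export at all:

```python
# Diagnostic sketch: print the state_dict entries of the converted model whose
# value is None; these are presumably what torch.jit._unique_state_dict chokes on.
for name, value in model_int8.state_dict().items():
    if value is None:
        print('None-valued state_dict entry:', name)
```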