It seems that after I serialize a quantized model and load it back, the output type of quantized operators, QUInt8, is lost and replaced by the default float Tensor type. See below for a module with a single quantized conv layer.
Before torch.jit.save
graph(%self.1 : __torch__.AnnotatedConvModel,
%X : Float(2, 3, 10, 10)):
...
%input : QUInt8(2, 3, 10, 10) = aten::quantize_per_tensor(%X, %67, %68, %69), scope: __module.quant # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:43:0
...
%Xq : QUInt8(2, 3, 8, 8) = quantized::conv2d(%input, %71, %74, %77, %80, %81, %82, %83), scope: __module.conv # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:215:0
%85 : Float(2, 3, 8, 8) = aten::dequantize(%Xq), scope: __module.dequant # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:74:0
return (%85)
After torch.jit.load
graph(%self.1 : __torch__.AnnotatedConvModel,
%X.1 : Tensor):
...
%input.1 : Tensor = aten::quantize_per_tensor(%X.1, %9, %10, %11) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:43:0
%Xq.1 : Tensor = quantized::conv2d(%input.1, %15, %17, %18, %19, %16, %20, %21) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:215:0
...
%24 : Tensor = aten::dequantize(%Xq.1) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:74:0
return (%24)
The PyTorch frontend in TVM uses this tensor type information to decide whether a torch op is invoked on a quantized tensor. See for example the conversion of adaptive average pooling, which requires special handling in the quantized case even though the same op, aten::adaptive_avg_pool2d, appears in the Torch IR for both float and quantized inputs.
Without correct typing, we cannot convert serialized quantized PyTorch models. What happens right now is that, since Torch tells TVM the input tensor is of float type, TVM incorrectly converts some quantized ops into float ops.
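For concreteness, here is a minimal sketch (not TVM's actual code) of the kind of type-based dispatch a frontend has to do; is_quantized_input is a hypothetical helper, and it assumes an inlined trace graph whose values still carry scalar types:

import torch

def is_quantized_input(node):
    # Hypothetical helper: decide whether a Torch IR node consumes a
    # quantized tensor by reading the scalar type of its first input.
    ty = node.inputsAt(0).type()
    if isinstance(ty, torch._C.TensorType):
        # scalarType() yields e.g. "Float" or "QUInt8", or None when the
        # value is only typed as a plain Tensor (the post-load case).
        return ty.scalarType() == "QUInt8"
    return False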
A repro script, tested on PyTorch v1.5:
import torch
from torch.quantization import QuantStub, DeQuantStub, default_qconfig


class AnnotatedConvModel(torch.nn.Module):
    def __init__(self):
        super(AnnotatedConvModel, self).__init__()
        self.qconfig = default_qconfig
        self.conv = torch.nn.Conv2d(3, 3, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x


def quantize_model(model, inp):
    model.qconfig = default_qconfig
    torch.quantization.prepare(model, inplace=True)
    model(inp)  # calibrate with one batch
    torch.quantization.convert(model, inplace=True)


def test_conv():
    inp = torch.rand(2, 3, 10, 10)
    annotated_conv_model = AnnotatedConvModel()
    quantize_model(annotated_conv_model, inp)

    trace = torch.jit.trace(annotated_conv_model, inp)
    torch._C._jit_pass_inline(trace.graph)
    print(trace.graph)  # quantized values are typed QUInt8 here

    torch.jit.save(trace, "trace.pt")
    trace = torch.jit.load("trace.pt")
    print(trace.graph)  # the same values come back as plain Tensor


test_conv()
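To check the regression programmatically rather than by eyeballing the dumps, one can also walk the loaded graph and read the scalar types directly; a minimal sketch, assuming the loaded graph exposes quantized::conv2d as in the dump above:

loaded = torch.jit.load("trace.pt")
for node in loaded.graph.nodes():
    if node.kind() == "quantized::conv2d":
        # Before save this value was typed QUInt8(2, 3, 8, 8); after
        # load, scalarType() comes back as None (plain Tensor).
        print(node.kind(), node.output().type().scalarType())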