Output tensor type is lost after serializing and loading back a quantized model

@jerryzh168 @raghuramank100

It seems that after I serialize and load back a quantized model, the QUInt8 output type of quantized operators is lost and replaced by a plain Tensor type. See below for the graph of a module with a single quantized conv layer.

Before torch.jit.save

graph(%self.1 : __torch__.AnnotatedConvModel,
      %X : Float(2, 3, 10, 10)):
  ...
  %input : QUInt8(2, 3, 10, 10) = aten::quantize_per_tensor(%X, %67, %68, %69), scope: __module.quant # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:43:0
  ...
  %Xq : QUInt8(2, 3, 8, 8) = quantized::conv2d(%input, %71, %74, %77, %80, %81, %82, %83), scope: __module.conv # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:215:0
  %85 : Float(2, 3, 8, 8) = aten::dequantize(%Xq), scope: __module.dequant # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:74:0
  return (%85)

After torch.jit.load

graph(%self.1 : __torch__.AnnotatedConvModel,
      %X.1 : Tensor):
  ...
  %input.1 : Tensor = aten::quantize_per_tensor(%X.1, %9, %10, %11) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:43:0
  %Xq.1 : Tensor = quantized::conv2d(%input.1, %15, %17, %18, %19, %16, %20, %21) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:215:0
  ...
  %24 : Tensor = aten::dequantize(%Xq.1) # /home/masa/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:74:0
  return (%24)

The PyTorch frontend in TVM uses this tensor type information to decide whether a torch op is invoked on a quantized tensor. See, for example, the conversion of adaptive average pooling: it requires special handling in the quantized case, yet in the Torch IR the same op aten::adaptive_avg_pool2d appears for both float and quantized inputs.

Without correct typing, we cannot convert serialized quantized PyTorch models. Right now, because Torch reports the input tensor as a float tensor, TVM incorrectly converts some quantized ops into float ops.
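For reference, here is a minimal sketch (not TVM's actual frontend code; the helper names are mine) of the kind of check this relies on: reading the scalar type annotation from a value in the traced graph. Before torch.jit.save, scalarType() returns "QUInt8" for the quantized values above; on the loaded graph it presumably returns None, since the dtype annotation is gone.

import torch

# Hypothetical helper: check whether a graph value is annotated as a
# quantized uint8 tensor.
def is_quint8(value):
    ty = value.type()
    return isinstance(ty, torch._C.TensorType) and ty.scalarType() == "QUInt8"

# Example: list the ops in a (traced, inlined) graph whose first output is QUInt8.
def quint8_nodes(graph):
    return [node.kind() for node in graph.nodes()
            if node.outputsSize() > 0 and is_quint8(node.outputsAt(0))]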

A repro script, tested on v1.5

import torch
from torch.quantization import QuantStub, DeQuantStub, default_qconfig


class AnnotatedConvModel(torch.nn.Module):
    def __init__(self):
        super(AnnotatedConvModel, self).__init__()
        self.qconfig = default_qconfig
        self.conv = torch.nn.Conv2d(3, 3, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x


def quantize_model(model, inp):
    model.qconfig = default_qconfig
    torch.quantization.prepare(model, inplace=True)
    model(inp)
    torch.quantization.convert(model, inplace=True)


def test_conv():
    inp = torch.rand(2, 3, 10, 10)
    annotated_conv_model = AnnotatedConvModel()
    quantize_model(annotated_conv_model, inp)

    trace = torch.jit.trace(annotated_conv_model, inp)
    torch._C._jit_pass_inline(trace.graph)
    print(trace.graph)

    torch.jit.save(trace, "trace.pt")
    trace = torch.jit.load("trace.pt")
    print(trace.graph)


test_conv()

Also posted on GitHub: https://github.com/pytorch/pytorch/issues/39690

What if you run the graph on some sample data after you load the model?

You can use the type, but don't rely on the shape, since it will likely change every time you run the model with an input of a different shape.

def test_conv():
    inp = torch.rand(2, 3, 10, 10)
    annotated_conv_model = AnnotatedConvModel()
    quantize_model(annotated_conv_model, inp)

    trace = torch.jit.trace(annotated_conv_model, inp)
    torch._C._jit_pass_inline(trace.graph)
    print(trace.graph)

    torch.jit.save(trace, "trace.pt")
    loaded = torch.jit.load("trace.pt")

    for i in range(5):
        out = loaded(torch.rand(2, 3, 10, 10))

    print(loaded.graph)

I tried running the loaded graph with some inputs, and it still says

%Xq.1 : Tensor = quantized::conv2d(...)

@jerryzh168 Correct me if I am wrong, but I think that's what the JIT does irrespective of whether the model is quantized or not. I believe we should talk to the JIT team about somehow exposing the dtype.
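In the meantime, one workaround I can imagine (a sketch only, using plain graph inspection rather than any official API): instead of relying on the dtype annotation, walk the loaded graph and propagate "quantized-ness" from the ops that are known to produce quantized tensors, such as aten::quantize_per_tensor and the quantized:: ops. The pass-through op list below is illustrative, not exhaustive, and assumes the graph has been inlined as in the repro above.

import torch

def find_quantized_values(graph):
    # Values (by debug name) known to hold quantized tensors.
    quantized = set()
    # Ops that preserve the quantized dtype of their first input (illustrative subset).
    passthrough = {"aten::adaptive_avg_pool2d", "aten::max_pool2d",
                   "aten::flatten", "aten::relu"}
    for node in graph.nodes():
        kind = node.kind()
        if kind == "aten::quantize_per_tensor" or kind.startswith("quantized::"):
            # These ops always produce quantized outputs.
            quantized.update(out.debugName() for out in node.outputs())
        elif kind in passthrough:
            inputs = list(node.inputs())
            if inputs and inputs[0].debugName() in quantized:
                quantized.update(out.debugName() for out in node.outputs())
    return quantized

# Usage on the loaded trace from the repro above:
# loaded = torch.jit.load("trace.pt")
# print(find_quantized_values(loaded.graph))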