Quantizing to int8 without stubs for input and output?

Hi,

I want to quantize a model so that I can run it without the quantization stubs and pass int8 tensors in directly. I followed some of the tutorials and previous discussions on this forum. Here is the code I am currently using to experiment:


import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.backend_config import get_native_backend_config


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x


# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

example_inputs = torch.rand(1,1,1,1)
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)

prepare_custom_config = PrepareCustomConfig() 
prepare_custom_config.set_input_quantized_indexes([0])  
prepare_custom_config.set_output_quantized_indexes([0])


backend_config = get_native_backend_config()
prepared_model = prepare_fx(model_fp32, qconfig_mapping, example_inputs,
                            prepare_custom_config, backend_config=backend_config)


def calibrate(model):
    model.eval()
    with torch.no_grad():
        for i in range(10):
            rand_in = torch.rand(1,1,1,1)
            model(rand_in)

calibrate(prepared_model)

quantized_model = convert_fx(prepared_model)

example_inputs_int8 = torch.randint(-122, 123, (1,1,1,1), dtype=torch.int8)
res = prepared_model(example_inputs_int8)

When I try to run this I get this error:
RuntimeError: Input type (signed char) and bias type (float) should be the same

It seems like the biases are not converted. What do I need to pass here to fix this? Thanks in advance!

That’s not going to work because, in the final converted model, the quant stub and dequant stub aren’t producing/receiving int8s but quint8s. The quant stub does torch.quantize_per_tensor and the dequant stub does tensor.dequantize().

Your best bet is to apply quantization as normal, and then you can change self.quant to nn.Identity and pass in a tensor with dtype quint8. Or convert self.quant into something that takes in int8 data and outputs quint8 data by doing something like:

x_fp32 = (x_int8 - zero_point) * scale

torch.quantize_per_tensor(x_fp32, scale, zero_point, torch.quint8)
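For concreteness, here is a minimal sketch of that conversion, assuming you know both sets of quantization parameters; the names int8_to_quint8, in_scale, in_zp, q_scale and q_zp are made up for illustration (for quint8 the zero point lives in [0, 255], so it will generally differ from the int8 one):

import torch

def int8_to_quint8(x_int8, in_scale, in_zp, q_scale, q_zp):
    # undo the int8 quantization to recover float values
    x_fp32 = (x_int8.to(torch.float32) - in_zp) * in_scale
    # re-quantize into the quint8 representation the quantized modules expect
    return torch.quantize_per_tensor(x_fp32, q_scale, q_zp, torch.quint8)

# example usage with arbitrary parameters
x_int8 = torch.randint(-128, 127, (1, 1, 1, 1), dtype=torch.int8)
x_q = int8_to_quint8(x_int8, in_scale=0.05, in_zp=0, q_scale=0.05, q_zp=128)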

Is there a way to do this entirely without the stubs? Really the only thing I care about is that the model receives and produces quantized values. I’m assuming that these operations are handled separately.

Just don’t use the stubs and then do normal eager mode quantization.

But your quantized values need to have dtype quint8, not int8.
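In case it helps, a rough sketch of that, reusing the same tiny conv/relu model from the question but without the stubs (the class name and the example input scale/zero_point are made up):

import torch

class MNoStubs(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

m = MNoStubs().eval()
m.qconfig = torch.ao.quantization.get_default_qconfig("x86")
prepared = torch.ao.quantization.prepare(m)

# calibration still runs on float tensors, since the prepared model only
# wraps the float modules with observers
with torch.no_grad():
    for _ in range(10):
        prepared(torch.rand(1, 1, 1, 1))

quantized = torch.ao.quantization.convert(prepared)

# the converted conv now expects a quantized (quint8) tensor directly,
# and with no DeQuantStub the output also stays quint8
x_q = torch.quantize_per_tensor(torch.rand(1, 1, 1, 1), scale=0.05,
                                zero_point=128, dtype=torch.quint8)
out_q = quantized(x_q)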

You can refer to this for what you want; it looks possible with FX graph mode.

According to the PyTorch quantization docs, it seems using stubs is a must for static PTQ. Do you have any code snippet or links showing static PTQ without QuantStub?

You do it normally, but without the stubs.

The stubs are there to tell the quantization API where to convert from float to quint8 and back. You can just run the quantizer without the stubs and it will still quantize the modules, but if you then try to run the model normally it will error, because you would be passing float inputs to quantized modules that need quint8 inputs.
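A small illustration of that behaviour with a standalone quantized conv module (the torch.ao.nn.quantized import path and the example scale/zero_point are assumptions here):

import torch
import torch.ao.nn.quantized as nnq

qconv = nnq.Conv2d(1, 1, 1)   # a quantized module: it only accepts quint8 input
x_fp32 = torch.rand(1, 1, 1, 1)

try:
    qconv(x_fp32)             # float input -> error, as described above
except Exception as e:
    print("float input rejected:", type(e).__name__)

x_q = torch.quantize_per_tensor(x_fp32, scale=0.05, zero_point=128,
                                dtype=torch.quint8)
out = qconv(x_q)              # quantized input works; the output is quint8 too
print(out.dtype, out.int_repr().dtype)   # torch.quint8 torch.uint8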