Understanding quantized linear layer

I’m trying to understand the implementation of the quantized linear layer with fbgemm.
For this I’m trying to reproduce the result in python for a simple linear layer without bias, but have failed to do so.

For my implementation I have looked at following files:

The function I use to compute the quantized linear layer is the following:

def quantize_lin_layer(x: torch.quint8, weight_data: torch.qint8, weight_scale, weight_zero_point, scale_x, zp_x,
                       scale_out, zero_point_out):
    """
    Implementation of a quantized linear layer without bias.

    :param x: quantized input
    :param scale_x: scale of quantized input
    :param zp_x: zero point of quantized input
    :param weight_data: quantized weight
    :param weight_scale: scale for quantized weight
    :param weight_zero_point: zero_point for quantized weight
    :param scale_out: scale of quantized output
    :param zero_point_out: zero point of quantized output
    :return: requantized output of linear layer
    """

    # Preparing input by shifting
    X = x.float() - zp_x - weight_zero_point

    # Compute output rescaler
    act_times_w_scale = scale_x * weight_scale
    output_multiplier_float = act_times_w_scale / scale_out

    x = (torch.nn.functional.linear(X, weight=weight_data.float()) * output_multiplier_float).round() + zero_point_out
    return torch.max(torch.tensor(0), torch.min(torch.tensor(255), x))

Does this look correct? This function does not give the same result than pytorch.

The following code fails with: AssertionError: tensor([[ 0., 255.]]) is not equal to tensor([[ 59., 253.]]).

from collections import namedtuple

import torch.nn
from torch.ao.quantization.observer import MinMaxObserver
import torch
import torch.nn.functional as F
import random
import numpy as np

np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])


class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.fc = torch.nn.Linear(2, 2, bias=False)
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inp):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        inp = self.quant(inp)
        x = self.fc(inp)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x


def quantize_tensor_unsigned(x, scale, zero_point, num_bits=8):
    qmin = 0.
    qmax = 2. ** num_bits - 1.

    q_x = zero_point + x / scale
    q_x.clamp_(qmin, qmax).round_()
    q_x = q_x.round().byte()

    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)


def quantize_lin_layer(x: torch.quint8, weight_data: torch.qint8, weight_scale, weight_zero_point, scale_x, zp_x,
                       scale_out, zero_point_out):
    """
    Implementation of a quantized linear layer without bias.

    :param x: quantized input
    :param scale_x: scale of quantized input
    :param zp_x: zero point of quantized input
    :param weight_data: quantized weight
    :param weight_scale: scale for quantized weight
    :param weight_zero_point: zero_point for quantized weight
    :param scale_out: scale of quantized output
    :param zero_point_out: zero point of quantized output
    :return: requantized output of linear layer
    """

    # Preparing input by shifting
    X = x.float() - zp_x - weight_zero_point

    # Compute output rescaler
    act_times_w_scale = scale_x * weight_scale
    output_multiplier_float = act_times_w_scale / scale_out

    x = (torch.nn.functional.linear(X, weight=weight_data.float()) * output_multiplier_float).round() + zero_point_out
    return torch.max(torch.tensor(0), torch.min(torch.tensor(255), x))


def main():
    # create a model instance
    model_fp32 = M()

    # model must be set to eval mode for static quantization logic to work
    model_fp32.eval()

    # attach a global qconfig, which contains information about what kind
    # of observers to attach. Use 'fbgemm' for server inference and
    # 'qnnpack' for mobile inference. Other quantization configurations such
    # as selecting symmetric or assymetric quantization and MinMax or L2Norm
    # calibration techniques can be specified here.
    model_fp32.qconfig = torch.quantization.QConfig(
        activation=MinMaxObserver.with_args(dtype=torch.quint8),
        weight=MinMaxObserver.with_args(dtype=torch.qint8)
    )

    # Prepare the model for static quantization. This inserts observers in
    # the model that will observe activation tensors during calibration.
    model_fp32_prepared = torch.quantization.prepare(model_fp32)

    # calibrate the prepared model to determine quantization parameters for activations
    # in a real world setting, the calibration would be done with a representative dataset
    input_fp32 = torch.randn(1, 2)
    model_fp32_prepared(input_fp32)

    # Convert the observed model to a quantized model. This does several things:
    # quantizes the weights, computes and stores the scale and bias value to be
    # used with each activation tensor, and replaces key operators with quantized
    # implementations.
    model_int8 = torch.quantization.convert(model_fp32_prepared)

    # compute quantized result with pytorch quantization
    pytorch_res = model_int8.fc(model_int8.quant(input_fp32)).int_repr().float()

    # compare with manual computation
    # convert input to unsigned int8
    quant_input_unsigned = quantize_tensor_unsigned(input_fp32, model_int8.quant(input_fp32).q_scale(),
                                                    model_int8.quant(input_fp32).q_zero_point())
    # compute manual quantized linear operation
    unsigned_result = quantize_lin_layer(x=quant_input_unsigned.tensor,
                                         scale_x=quant_input_unsigned.scale,
                                         zp_x=quant_input_unsigned.zero_point,
                                         weight_data=model_int8.fc.weight().int_repr(),
                                         weight_scale=model_int8.fc.weight().q_scale(),
                                         weight_zero_point=model_int8.fc.weight().q_zero_point(),
                                         scale_out=model_int8.fc.scale, zero_point_out=model_int8.fc.zero_point,
                                         )

    assert torch.equal(unsigned_result, pytorch_res), f"{unsigned_result} is not equal to {pytorch_res}."


if __name__ == '__main__':
    main()

1 Like

Assuming that both your input and weight are in the integer representation. Remove their respective zero_point and perform the integer matrix-multiplication. Rescale the result with your output_multiplier_float and add the output zero_point. Finally, clamp the result to be represented as a uint8. :slightly_smiling_face:

def quantize_lin_layer(x: torch.quint8, weight_data: torch.qint8, weight_scale, weight_zero_point, scale_x, zp_x,
                       scale_out, zero_point_out):
    """
    Implementation of a quantized linear layer without bias.

    :param x: quantized input
    :param scale_x: scale of quantized input
    :param zp_x: zero point of quantized input
    :param weight_data: quantized weight
    :param weight_scale: scale for quantized weight
    :param weight_zero_point: zero_point for quantized weight
    :param scale_out: scale of quantized output
    :param zero_point_out: zero point of quantized output
    :return: requantized output of linear layer
    """
    return torch.max(
            torch.tensor(0),
            torch.min(
                torch.tensor(255),
                torch.round(torch.nn.functional.linear(
                    x.float() - zp_x, weight_data.float() - weight_zero_point)
                ) * (scale_x * weight_scale) / scale_out + zero_point_out)
        )
1 Like

I get the closest result to pytorch’s quantized linear layer output with this function:

def quantize_lin_layer(x: torch.quint8, weight_data: torch.qint8, weight_scale, weight_zero_point, scale_x, zp_x,
                       scale_out, zero_point_out):
    """
    Implementation of a quantized linear layer without bias.

    :param x: quantized input
    :param scale_x: scale of quantized input
    :param zp_x: zero point of quantized input
    :param weight_data: quantized weight
    :param weight_scale: scale for quantized weight
    :param weight_zero_point: zero_point for quantized weight
    :param scale_out: scale of quantized output
    :param zero_point_out: zero point of quantized output
    :return: requantized output of linear layer
    """
    return torch.max(
        torch.tensor(0).double(),
        torch.min(
            torch.tensor(255).double(),
            torch.round(
                torch.nn.functional.linear(
                    x.double() - zp_x,
                    weight_data.double() - weight_zero_point
                ) * (scale_x * weight_scale) / scale_out
            ) + zero_point_out
        )
    )

However, few values still mismatch. Could there be rounding errors?