Understanding quantized linear layer

I get the closest match to PyTorch's quantized linear layer output with this function:

import torch


def quantize_lin_layer(x: torch.Tensor, weight_data: torch.Tensor, weight_scale, weight_zero_point,
                       scale_x, zp_x, scale_out, zero_point_out):
    """
    Implementation of a quantized linear layer without bias.

    :param x: integer representation (int_repr()) of the quint8 input
    :param scale_x: scale of the quantized input
    :param zp_x: zero point of the quantized input
    :param weight_data: integer representation (int_repr()) of the qint8 weight
    :param weight_scale: scale of the quantized weight
    :param weight_zero_point: zero point of the quantized weight
    :param scale_out: scale of the quantized output
    :param zero_point_out: zero point of the quantized output
    :return: requantized output of the linear layer
    """
    # Integer-domain matmul, accumulated in float64: (x - zp_x) @ (W - zp_w)^T
    acc = torch.nn.functional.linear(
        x.double() - zp_x,
        weight_data.double() - weight_zero_point,
    )
    # Requantize: rescale to the output scale, round, add the output
    # zero point, and clamp to the quint8 range [0, 255]
    return torch.clamp(
        torch.round(acc * (scale_x * weight_scale) / scale_out) + zero_point_out,
        min=0.0,
        max=255.0,
    )
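
For reference, here is a minimal sketch of the kind of comparison I run against PyTorch's quantized Linear module. The shapes, scales, and zero points below are illustrative assumptions, not the values from my actual model:

import torch

torch.manual_seed(0)

x_float = torch.randn(4, 8)
w_float = torch.randn(3, 8)

# Illustrative quantization parameters (assumptions for this sketch)
scale_x, zp_x = 0.05, 128
weight_scale, weight_zero_point = 0.04, 0
scale_out, zero_point_out = 0.1, 128

x_q = torch.quantize_per_tensor(x_float, scale_x, zp_x, torch.quint8)
w_q = torch.quantize_per_tensor(w_float, weight_scale, weight_zero_point, torch.qint8)

# Reference: PyTorch's quantized linear module without bias
# (torch.nn.quantized.Linear in older releases)
qlin = torch.ao.nn.quantized.Linear(8, 3, bias_=False)
qlin.set_weight_bias(w_q, None)
qlin.scale = scale_out
qlin.zero_point = zero_point_out
y_ref = qlin(x_q).int_repr().double()

# Manual re-implementation, fed the integer representations
y_manual = quantize_lin_layer(
    x_q.int_repr(), w_q.int_repr(), weight_scale, weight_zero_point,
    scale_x, zp_x, scale_out, zero_point_out,
)

# A small nonzero max here corresponds to the mismatched values
print((y_ref - y_manual).abs().max())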

However, a few values still mismatch. Could this be due to rounding errors?
