I get the closest result to PyTorch's quantized linear layer output with this function:
import torch


def quantize_lin_layer(x: torch.Tensor, weight_data: torch.Tensor,
                       weight_scale, weight_zero_point,
                       scale_x, zp_x, scale_out, zero_point_out):
    """
    Implementation of a quantized linear layer without bias.

    :param x: quantized input (quint8 values stored as integers)
    :param weight_data: quantized weight (qint8 values stored as integers)
    :param weight_scale: scale of the quantized weight
    :param weight_zero_point: zero point of the quantized weight
    :param scale_x: scale of the quantized input
    :param zp_x: zero point of the quantized input
    :param scale_out: scale of the quantized output
    :param zero_point_out: zero point of the quantized output
    :return: requantized output of the linear layer, clamped to [0, 255]
    """
    # Integer matmul on the zero-point-shifted values.
    acc = torch.nn.functional.linear(
        x.double() - zp_x,
        weight_data.double() - weight_zero_point,
    )
    # Requantize: rescale, round, add the output zero point, and only
    # then clamp to the quint8 range. (Clamping to 255 before adding the
    # zero point, as in my first attempt, lets values exceed 255.)
    return torch.clamp(
        torch.round(acc * (scale_x * weight_scale) / scale_out) + zero_point_out,
        0, 255,
    )
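For reference, here is the self-contained sanity check I run, comparing the integer-domain path against a plain float path (dequantize, run the linear layer, requantize). The scales, zero points, and tensor shapes are arbitrary values I picked for the test, not from a real model:

```python
import torch

torch.manual_seed(0)

# Arbitrary example quantization parameters (chosen for the test only).
scale_x, zp_x = 0.05, 128
w_scale, w_zp = 0.02, 0
scale_out, zp_out = 0.1, 128

# Random integer tensors standing in for quint8 input / qint8 weight.
x_q = torch.randint(0, 256, (4, 8))
w_q = torch.randint(-128, 128, (3, 8))

# Integer-domain simulation (same arithmetic as the function above).
acc = torch.nn.functional.linear(x_q.double() - zp_x,
                                 w_q.double() - w_zp)
y_q = torch.clamp(
    torch.round(acc * (scale_x * w_scale) / scale_out) + zp_out,
    0, 255,
)

# Float reference: dequantize, apply the linear layer, requantize.
x_f = (x_q.double() - zp_x) * scale_x
w_f = (w_q.double() - w_zp) * w_scale
y_ref = torch.clamp(
    torch.round(torch.nn.functional.linear(x_f, w_f) / scale_out) + zp_out,
    0, 255,
)

print((y_q - y_ref).abs().max().item())
```

The two paths agree up to at most one rounding step in my runs, since the only difference is the order in which the double-precision scale factors are applied.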
However, a few values still mismatch. Could this be caused by rounding errors?