How to use quantize_per_tensor

I have a model which is trained in Kaldi and I’m able to load the model parameters in PyTorch as tensors.
I am trying to perform post-training quantization of the weight matrices, and I've tried to use the quantize_per_tensor function.
For example:

import torch

qmin, qmax = -128, 127  # representable range for torch.qint8

a = torch.rand(10)
b = torch.rand(10)
min_a, max_a = a.min().item(), a.max().item()
min_b, max_b = b.min().item(), b.max().item()
scale_a = (max_a - min_a) / (qmax - qmin)
zpt_a = qmin - min_a / scale_a
scale_b = (max_b - min_b) / (qmax - qmin)
zpt_b = qmin - min_b / scale_b
a_quant = torch.quantize_per_tensor(a, scale_a, -127, torch.qint8)
b_quant = torch.quantize_per_tensor(b, scale_b, -127, torch.qint8)
a_quant + b_quant

When I add the two quantized tensors, I get the error below:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Could not run 'aten::add.Tensor' with arguments from the 'QuantizedCPU' backend. 'aten::add.Tensor' is only available for these backends: [CPU, CUDA, MkldnnCPU, SparseCPU, SparseCUDA, Meta, Named, Autograd, Profiler, Tracer].

It seems that I can convert fp32 to int8 but not perform any integer arithmetic.
Any help on how to use this would be appreciated.

Thanks!


@aprasad: If you are just looking to quantize the weights only and want to keep activations in fp32, please look into dynamic quantization. It does exactly that, and it will calculate the scales/zero_points automatically for you as well.
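For example, something along these lines (a minimal sketch with a toy model in place of your Kaldi-trained one):

import torch

# toy float model; in your case this would be the model loaded from Kaldi
model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU())

# weights are stored as int8; activations stay fp32 and are quantized
# on the fly inside each dynamically quantized Linear
qmodel = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

out = qmodel(torch.rand(1, 10))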

BTW, if you want to use add (or other such operations) on quantized tensors, you can do it in the following way:

qfn = torch.nn.quantized.QFunctional()
qfn.add(a_quant, b_quant)  # re-quantizes the sum with qfn's scale / zero_point

https://pytorch.org/docs/stable/quantization.html#qfunctional
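In case it is useful: the output quantization parameters live on the QFunctional object itself, so you can set them before calling add (a small sketch; the values here are arbitrary):

qfn = torch.nn.quantized.QFunctional()
qfn.scale = 0.1       # output scale used to re-quantize the sum
qfn.zero_point = 0    # output zero_point
c_quant = qfn.add(a_quant, b_quant)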


@dskhudia: Thank you for the reply. The quantized functional works.
I want to quantize both the weights and the activations and run inference in PyTorch with my custom class.

Is there a function that calculates the scale/zero_points automatically for that?

@dskhudia: I tried using qfn to multiply two matrices, and I get the error below.

For a of shape (1, 10) and b of shape (10, 20), if I do

qfn.mul(a_quant, b_quant)

I get,

r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
RuntimeError: The size of tensor a (10) must match the size of tensor b (20) at non-singleton dimension 1

Is there any function for matrix multiplication of quantized tensors?

You can use linear for that, I think.
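For example, something like this should work (a minimal sketch; the scales and zero_points here are arbitrary placeholders, and note that the backend expects a quint8 input and a qint8 weight with zero_point 0):

import torch
import torch.nn.quantized.functional as qF

x = torch.rand(1, 10)
w = torch.rand(20, 10)  # nn.Linear convention: (out_features, in_features)

xq = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.quint8)
wq = torch.quantize_per_tensor(w, scale=0.05, zero_point=0, dtype=torch.qint8)

# computes xq @ wq.T in the quantized domain; result has shape (1, 20)
yq = qF.linear(xq, wq, bias=None, scale=0.1, zero_point=0)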


Hi, I have a question about the quantization scheme.

When I do xq = torch.quantize_per_tensor(x, scale=0.25, zero_point=15, dtype=torch.quint8),
the result is a tensor with quantization_scheme=torch.per_tensor_affine. I wonder how I can change the scheme to per_tensor_symmetric?

We do not actually have per_tensor_symmetric tensors in the backend, since per_tensor_symmetric can be represented by a per_tensor_affine tensor: e.g., a torch.per_tensor_symmetric, torch.qint8 tensor with a given scale would be the same as a torch.per_tensor_affine, torch.qint8 tensor with the same scale and a zero_point of 0. You can use per_tensor_symmetric in observers, though.
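For example (a minimal sketch using MinMaxObserver; under the symmetric scheme the computed zero_point comes out as 0 for qint8, so the resulting tensor is affine but effectively symmetric):

import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
x = torch.randn(100)
obs(x)  # record min/max statistics

scale, zero_point = obs.calculate_qparams()
xq = torch.quantize_per_tensor(x, scale.item(), int(zero_point.item()), torch.qint8)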

Thanks for your reply. I've been looking into the histogram observer recently, but I wonder: is the norm computation correct? Why are there 3 norms computed?

def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int):
    r"""
    Compute the quantization error if we use start_bin to end_bin as the
    min and max to do the quantization.
    """

    bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

    # compute new bin_width
    dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins      # divided by 256, quantized range is [-128, 127]
    if dst_bin_width == 0.0:
        return 0.0

    src_bin = torch.arange(self.bins).to(self.device)
    # distances from the beginning of first dst_bin to the beginning and
    # end of src_bin
    src_bin_begin = (src_bin - next_start_bin) * bin_width
    src_bin_end = src_bin_begin + bin_width

    # which dst_bins the beginning and end of src_bin belong to?
    dst_bin_of_begin = torch.clamp(src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1)
    dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width

    dst_bin_of_end = torch.clamp(src_bin_end // dst_bin_width, 0, self.dst_nbins - 1)
    dst_bin_of_end_center = (dst_bin_of_end + 0.5) * dst_bin_width

    density = self.histogram / bin_width

    norm = torch.zeros(self.bins)

    delta_begin = src_bin_begin - dst_bin_of_begin_center
    delta_end = dst_bin_width / 2

    delta_begin = delta_begin.to(self.device)

    ## the norm accumulates from 3 parts per src bin: the partial dst bin
    ## containing the begin of the src bin, the full dst bins in between,
    ## and the partial dst bin containing the end of the src bin
    norm += self._get_norm(delta_begin, torch.ones(self.bins) * delta_end, density)

    norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm(
        torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density
    )

    dst_bin_of_end_center = (
        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
    )

    delta_begin = -dst_bin_width / 2
    delta_end = src_bin_end - dst_bin_of_end_center
    norm += self._get_norm(torch.tensor(delta_begin), delta_end, density)

    return norm.sum().item()

cc @hx89 who implemented the histogram observer
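(For reference, a minimal sketch of exercising this code path; calculate_qparams runs the bin search that calls _compute_quantization_error internally:)

import torch
from torch.quantization import HistogramObserver

obs = HistogramObserver(bins=2048, dtype=torch.quint8)
obs(torch.randn(1000))
scale, zero_point = obs.calculate_qparams()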