Implement inference of a quantized network with primitive operations

I want to implement a quantized network in pure C. One of the purposes is to get a full understanding of how operations with quantized tensors work.

I did PTQ (post-training quantization) on a ResNet-18 architecture and got good accuracy with the fbgemm backend.

Now I am struggling to replicate the operations. I decided that the simplest place to start is the addition in a ResNet block:

self.skip_add = nn.quantized.FloatFunctional()

And during inference I can add two tensors via

out1 = self.skip_add.add(x1, x2)

where x1 and x2 are tensors of torch.Tensor type, quantized with the fbgemm backend during the post-training quantization procedure.

I expected that out2_int = x1.int_repr() + x2.int_repr() would be the same as out1.int_repr() (probably with clamping to the valid range). However, that is not the case.

Can anyone please provide me with any information on how to implement operations with quantized tensors?

Below I dump the example outputs.

print(x1)
    ...,
      [-0.0596, -0.0496, -0.1390,  ..., -0.0596, -0.0695, -0.0099],
      [-0.0893,  0.0000, -0.0695,  ...,  0.0596, -0.0893, -0.0298],
      [-0.1092,  0.0099,  0.0000,  ..., -0.0397, -0.0794, -0.0199]]]],
   size=(1, 256, 14, 14), dtype=torch.quint8,
   quantization_scheme=torch.per_tensor_affine, scale=0.009925744496285915,
   zero_point=75)
print(x2)
      ...,
      [ 0.1390, -0.1669, -0.0278,  ..., -0.2225, -0.0556, -0.1112],
      [ 0.0000, -0.1669, -0.0556,  ...,  0.0556,  0.1112, -0.2781],
      [ 0.1390,  0.1669,  0.0278,  ...,  0.2225,  0.4171,  0.0834]]]],
   size=(1, 256, 14, 14), dtype=torch.quint8,
   quantization_scheme=torch.per_tensor_affine, scale=0.02780967578291893,
   zero_point=61)
print(x1.int_repr())
      ...,
      [69, 70, 61,  ..., 69, 68, 74],
      [66, 75, 68,  ..., 81, 66, 72],
      [64, 76, 75,  ..., 71, 67, 73]]]], dtype=torch.uint8)
print(x2.int_repr())
      ...,
      [66, 55, 60,  ..., 53, 59, 57],
      [61, 55, 59,  ..., 63, 65, 51],
      [66, 67, 62,  ..., 69, 76, 64]]]], dtype=torch.uint8)
print(self.skip_add.add(x1, x2))
      ...,
      [ 0.0904, -0.2109, -0.1808,  ..., -0.2712, -0.1205, -0.1205],
      [-0.0904, -0.1808, -0.1205,  ...,  0.1205,  0.0301, -0.3013],
      [ 0.0301,  0.1808,  0.0301,  ...,  0.1808,  0.3314,  0.0603]]]],
   size=(1, 256, 14, 14), dtype=torch.quint8,
   quantization_scheme=torch.per_tensor_affine, scale=0.03012925386428833,
   zero_point=56)
print(self.skip_add.add(x1, x2).int_repr())
      ...,
      [59, 49, 50,  ..., 47, 52, 52],
      [53, 50, 52,  ..., 60, 57, 46],
      [57, 62, 57,  ..., 62, 67, 58]]]], dtype=torch.uint8)
print(x1.int_repr() + x2.int_repr())
      [135, 125, 121,  ..., 122, 127, 131],
      [127, 130, 127,  ..., 144, 131, 123],
      [130, 143, 137,  ..., 140, 143, 137]]]], dtype=torch.uint8)

I have found that I had a misconception about quantized values: each byte value is a quant in its own specific domain, defined by that tensor's scale and zero point. I used the following paper as a guide: https://arxiv.org/pdf/1712.05877.pdf (though it goes a bit beyond PyTorch's implementation, since it describes integer-arithmetic-only quantization, whereas many operations in PyTorch quantization stay in float).
With that I was able to successfully reimplement the convolution and fully connected layers, but nn.quantized.FloatFunctional().add() still resists implementation. I thought it would be the easiest layer, heh.
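Roughly, the fully connected layer can be reproduced like this (a simplified numpy sketch with my own naming, assuming per-tensor quint8 quantization and a float requantization multiplier instead of the paper's fixed-point one):

import numpy as np

def quantized_fc(q_in, s_in, z_in, q_w, s_w, z_w, bias_i32, s_out, z_out):
    # q_in: (N, in_features) uint8, q_w: (out_features, in_features) uint8
    # bias_i32: int32 bias already quantized with scale s_in * s_w
    acc = (q_in.astype(np.int32) - z_in) @ (q_w.astype(np.int32) - z_w).T
    acc = acc + bias_i32
    m = (s_in * s_w) / s_out            # combined requantization multiplier
    q_out = np.round(m * acc) + z_out
    return np.clip(q_out, 0, 255).astype(np.uint8)  # clamp to the quint8 range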

At the moment I have tried the following reasoning:

S_3(q_3 - Z_3) = S_1(q_1 - Z_1) + S_2(q_2 - Z_2)

from where

q_3 = [S_1(q_1 - Z_1) + S_2(q_2 - Z_2)] / S_3 + Z_3

In that case almost half of the elements come out the same as in PyTorch, but the other half can differ drastically. A lot of the misses come from values that need to be clipped.
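As a sanity check of the formula, plugging in the first visible elements of the int_repr() dumps above (q1 = 69, q2 = 66) together with the printed scales and zero points:

s1, z1 = 0.009925744496285915, 75
s2, z2 = 0.02780967578291893, 61
s3, z3 = 0.03012925386428833, 56

q3 = (s1 * (69 - 75) + s2 * (66 - 61)) / s3 + z3
print(round(q3))  # 59, matching the first element of the add() output dump above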

import numpy as np

def manual_addition(x1, x2, add_layer):
    # add_layer is of type nn.quantized.FloatFunctional()
    q1 = x1.int_repr().numpy().astype(np.int32)  # cast away from uint8 so the subtraction below cannot wrap
    q2 = x2.int_repr().numpy().astype(np.int32)
    z1 = x1.q_zero_point()
    z2 = x2.q_zero_point()
    s1 = x1.q_scale()
    s2 = x2.q_scale()
    z3 = add_layer.zero_point
    s3 = add_layer.scale

    # dequantize each input: S * (q - Z)
    q3_1 = s1 * (q1 - z1)
    q3_2 = s2 * (q2 - z2)

    # add, then requantize with the output scale and zero point
    qres = q3_1 + q3_2
    qres = qres / s3 + z3
    q3_int32 = qres.round()

    q3 = q3_int32.clip(0, 255).astype(np.uint8)  # many misses are from clipped values
    gt_res = add_layer.add(x1, x2)
    gt_res_int = gt_res.int_repr().numpy()
    calc_hit_rate(q3, gt_res_int)  # reported hit rate = 0.51

If anyone has any clue on how the clipping should be done here, I would appreciate your help so much.

Hi @roman.spb, you can use the following equations to convert between the floating-point and quantized domains:

q = clamp(std::nearbyint(fp / scale) + zp, qmin, qmax)
fp = float(q - zp) * scale
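In numpy terms (a minimal sketch, assuming a per-tensor quint8 tensor with qmin = 0 and qmax = 255), these two equations become:

import numpy as np

def quantize(fp, scale, zp, qmin=0, qmax=255):
    # np.rint rounds to nearest even, like std::nearbyint in the default rounding mode
    return np.clip(np.rint(fp / scale) + zp, qmin, qmax).astype(np.uint8)

def dequantize(q, scale, zp):
    return (q.astype(np.float32) - zp) * scale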

Hi Vasiliy! Thank you so much for your help! These equations work for operating on a single number. But when I try to add numbers in the quantized domain and then clamp according to the formulas, I still get a very different result from the PyTorch implementation of

nn.quantized.FloatFunctional().add()

So it may be some issue other than just clamping.

I’ve tried adding two numbers in the float domain and still had no luck reproducing the add() behaviour.

Maybe you can help me understand how this function works internally? Or tell me where I can find the source code implementing it?

The implementation is:

Hi HDCharles,
The code you mention is not quite what's needed; it refers to the compiled C++ implementation of ops.quantized.add. It was very complicated to figure out where the actual code is. I found it by googling through issues in the PyTorch GitHub repository.

I still have not reproduced these operations, because the actual implementation is hidden behind other functions like Vectorized::dequantize, and there are lots of different implementations of those.

It looks to me like it's just dequantizing, adding together, and then requantizing:

import torch

x = torch.randn(10,10)
y = torch.randn(10,10)

#arbitrary scales and zero_points
zp_x = 1
zp_y = 2
zp_z = 3

s_x = .1
s_y = .2
s_z = .3

#quantize tensors
xq = torch.quantize_per_tensor(x, s_x, zp_x, torch.qint8)
yq = torch.quantize_per_tensor(y, s_y, zp_y, torch.qint8)

#setup add operation
add_op = torch.nn.quantized.QFunctional()
add_op.scale = s_z
add_op.zero_point = zp_z

zq_qfunc = add_op.add(xq,yq)
print("QFunctional output", zq_qfunc)

#manually do operation
xdq = xq.dequantize()
ydq = yq.dequantize()

#add together
zdq = xdq+ydq

#requantize
zq_manual = torch.quantize_per_tensor(zdq, s_z, zp_z, torch.qint8)

print("Manual output", zq_manual)

print("difference", zq_manual.int_repr()-zq_qfunc.int_repr())

You can get the specifics about the quant/dequant process formulas here: Quantization API Reference — PyTorch master documentation


HDCharles, thank you very much for your help!

It turned out to be as simple as you mentioned. Now it makes sense why it is named FloatFunctional. I thought I had tried this at the beginning of my efforts, but I must have made some mistake.

To add: dequantize() and quantize_per_tensor() are comparatively straightforward to implement with primitive operations.

import numpy as np

def manual_addition(xq1_int, scale1, zp1, xq2_int, scale2, zp2, scale_r, zp_r):
    # from uint8 to float (dequantize both inputs)
    xdq = scale1 * (xq1_int.astype(np.float64) - zp1)
    ydq = scale2 * (xq2_int.astype(np.float64) - zp2)

    # float addition
    zdq = xdq + ydq

    # from float back to uint8 (requantize with the result scale and zero point)
    zq_manual_int = (zdq / scale_r).round() + zp_r
    return zq_manual_int.clip(0, 255).astype(np.uint8)  # clip to the quint8 range
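For completeness, a usage sketch (the scales and zero points are assumed values, mirroring the QFunctional example above) that compares this manual version against PyTorch:

import torch

xq = torch.quantize_per_tensor(torch.randn(10, 10), 0.1, 1, torch.quint8)
yq = torch.quantize_per_tensor(torch.randn(10, 10), 0.2, 2, torch.quint8)

add_op = torch.nn.quantized.QFunctional()
add_op.scale = 0.3
add_op.zero_point = 3

zq_manual = manual_addition(xq.int_repr().numpy(), 0.1, 1,
                            yq.int_repr().numpy(), 0.2, 2, 0.3, 3)
zq_torch = add_op.add(xq, yq).int_repr().numpy()
print("hit rate:", (zq_manual == zq_torch).mean())  # close to 1.0, up to rounding differences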