Conv2d versus nn.quantized.Conv2d

Hi Team,

I am trying to understand the difference in output between Conv2d and nn.quantized.Conv2d (qconv2d). I create random input and weight tensors whose values fall within the int8 range, then compute the output of a Conv2d. I also have a custom conv2d method that is equivalent to Conv2d but performs the convolution with the unfold and fold functions, and both methods give exactly the same result. Next, I create quantized tensors for the same input and weights, calculating the scale and zero_point with the affine quantization scheme described in the PyTorch article Practical Quantization in PyTorch. I know I am not adding the bias, but the output of qconv2d is way off from the normal Conv2d output. Can anyone please explain the reason behind this? Also, when qint8 input and weight tensors are assigned, are the operations performed on the qint8 values directly, or are they converted into some other format in PyTorch's C++ code?
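For reference, the per-tensor affine scheme maps a float value x to an integer x_q using a scale s and an integer zero point z. PyTorch's min/max observers derive s and z from the observed range, extended to include 0, and then clamp z to the same integer range:

```latex
x_q = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right) + z,\; q_{\min},\; q_{\max}\right),
\qquad \hat{x} = s\,(x_q - z)

\alpha = \min(x_{\min}, 0), \qquad \beta = \max(x_{\max}, 0), \qquad
s = \frac{\beta - \alpha}{q_{\max} - q_{\min}}, \qquad
z = q_{\min} - \operatorname{round}\!\left(\frac{\alpha}{s}\right)

\texttt{quint8}: [q_{\min}, q_{\max}] = [0, 255], \qquad
\texttt{qint8}: [q_{\min}, q_{\max}] = [-128, 127]
```

Because z depends on q_min, parameters computed for an unsigned 0-based range cannot be reused unchanged for qint8.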

import random
import torch

random.seed(0)
torch.manual_seed(0)

# create random input and weight tensors with integer values in [0, 20),
# well inside the int8 range
input = torch.randint(0, 20, (1, 1, 3, 3), dtype=torch.float32)
weight = torch.randint(0, 20, (1, 1, 3, 3), dtype=torch.float32)

# Normal Convolution Layer 
m = torch.nn.Conv2d(1, 1, 3, stride=1)
m.weight.data = weight
conv_ouput = m(input)

print("Convolution Output :", conv_ouput)

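# custom convolution method output: unfold into patches, multiply by the
# flattened weight, then fold back into an image
# (the post uses its own Test helper class here, which is not shown;
#  torch.matmul and the output-size formula below stand in for it)
kernel_size, stride, pad = 3, 1, 0
out_size = (3 + 2 * pad - kernel_size) // stride + 1
out_shape = (out_size, out_size)  # (1, 1) for a 3x3 input and a 3x3 kernel
unfolded = torch.nn.functional.unfold(input, (kernel_size, kernel_size), padding=pad, stride=stride)
conv_output = torch.matmul(
    unfolded.transpose(1, 2), weight.view(weight.size(0), -1).t()
).transpose(1, 2)
conv_output_fold = torch.nn.functional.fold(conv_output, out_shape, (1, 1))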

conv_output_fold = conv_output_fold+m.bias
print("Custom Convolution Output", conv_output_fold)


# quantized convolutional layer 
bitwidth = 8
qm = torch.nn.quantized.Conv2d(1, 1, 3, stride=1)

# calculate scale and zero point for input and weight 
weight_alpha = weight.min()
weight_beta = weight.max()
weight_scale = (weight_beta-weight_alpha)/(2**bitwidth-1-1) 
weight_zero_point = -(weight_alpha/weight_scale - 0)

input_alpha = input.min()
input_beta = input.max()
input_scale = (input_beta-input_alpha)/(2**bitwidth-1-1) 
input_zero_point = -(input_alpha/input_scale - 0)

print("Bias", m.bias.item())

quantized_weight = torch.quantize_per_tensor(weight, scale=weight_scale, zero_point=weight_zero_point, dtype=torch.qint8)
qm.weight().data = quantized_weight
quantized_input = torch.quantize_per_tensor(input, scale=input_scale, zero_point=input_zero_point, dtype=torch.quint8)

print("Input tensor", input)
print("Weight tensor", weight)
print("Normal Convolution Output", conv_ouput)
print("Quantized Input", quantized_input)
print("Quantized Weight", quantized_weight)
print("Quantized Input Integer Representation", quantized_input.int_repr())
print("Quantized Weight Integer Representation", quantized_weight.int_repr())
quantized_output = qm(quantized_input)
print("Quantized Convolution Output", quantized_output)

Output:
Convolution Output : tensor([[[[803.6908]]]], grad_fn=<ConvolutionBackward0>)
Custom Convolution Output tensor([[[[803.6908]]]], grad_fn=<AddBackward0>)
Bias -0.309223473072052
Input tensor tensor([[[[ 4., 19., 13.],
          [ 0.,  3., 19.],
          [ 7.,  3., 17.]]]])
Weight tensor tensor([[[[ 3.,  1.,  6.],
          [16., 19., 18.],
          [16., 16.,  8.]]]])
Normal Convolution Output tensor([[[[803.6908]]]], grad_fn=<ConvolutionBackward0>)
Quantized Input tensor([[[[ 3.9646, 19.0000, 13.0157],
          [ 0.0000,  2.9921, 19.0000],
          [ 7.0315,  2.9921, 16.9803]]]], size=(1, 1, 3, 3),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=0.07480315119028091, zero_point=0)
Quantized Weight tensor([[[[2.9764, 0.9921, 6.0236],
          [9.9921, 9.9921, 9.9921],
          [9.9921, 9.9921, 8.0079]]]], size=(1, 1, 3, 3), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.07086614519357681,
       zero_point=-14)
Quantized Input Integer Representation tensor([[[[ 53, 254, 174],
          [  0,  40, 254],
          [ 94,  40, 227]]]], dtype=torch.uint8)
Quantized Weight Integer Representation tensor([[[[ 28,   0,  71],
          [127, 127, 127],
          [127, 127,  99]]]], dtype=torch.int8)
Quantized Convolution Output tensor([[[[0.]]]], size=(1, 1, 1, 1), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)
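The weight's integer representation above can be reproduced by hand. The sketch below is plain Python (no PyTorch) applying clamp(round(x / s) + z, q_min, q_max) with the post's weight parameters; the printed zero_point=-14 is consistent with -14.11 being truncated toward zero, which the sketch assumes. Those parameters describe a 0-to-254 range but are applied to qint8's [-128, 127], so every weight of 16 or more clips to 127 (about 9.99 after dequantization). The second half shows, as an illustration, parameters derived for the qint8 range instead.

```python
# Plain-Python sketch of per-tensor affine quantization (no PyTorch needed).
def quantize(values, scale, zero_point, qmin, qmax):
    """Map floats to clamped integers: clamp(round(x / scale) + zero_point)."""
    return [min(max(round(v / scale) + zero_point, qmin), qmax) for v in values]

def dequantize(qvalues, scale, zero_point):
    """Map integers back to floats: scale * (q - zero_point)."""
    return [scale * (q - zero_point) for q in qvalues]

# Weight values from the post, flattened row by row.
weight = [3., 1., 6., 16., 19., 18., 16., 16., 8.]

# The post's parameters: range [1, 19] spread over 2**8 - 1 - 1 = 254 steps.
w_min, w_max = min(weight), max(weight)
post_scale = (w_max - w_min) / (2**8 - 1 - 1)   # ~0.0709
post_zero_point = int(-(w_min / post_scale))     # -14.11 -> -14, as printed in the post
post_q = quantize(weight, post_scale, post_zero_point, -128, 127)
print(post_q)  # [28, 0, 71, 127, 127, 127, 127, 127, 99]: large weights clip to 127

# Parameters derived for qint8: range extended to include 0, all 255 steps used,
# and the zero point offset by qmin = -128.
qmin, qmax = -128, 127
lo, hi = min(w_min, 0.0), max(w_max, 0.0)
scale = (hi - lo) / (qmax - qmin)
zero_point = min(max(qmin - round(lo / scale), qmin), qmax)
fixed_q = quantize(weight, scale, zero_point, qmin, qmax)
print([round(v, 2) for v in dequantize(fixed_q, scale, zero_point)])
```

With these parameters every dequantized weight is within half a quantization step of its float value.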

@HDCharles @Vasiliy_Kuznetsov @jerryzh168

Thanks

One thing that is missing from the code is setting the scale and zero_point for the output of qm. Usually that is done by attaching an observer to the conv, calibrating or doing quantization-aware training (QAT) to collect output statistics, and then using those statistics to quantize the output.
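To illustrate why the output parameters matter: the printed output shows qm still using its defaults, scale=1.0 and zero_point=0, so its quint8 output can only represent 0 to 255, while the float result is about 803.69. Even with correct weights, that value would saturate. The plain-Python sketch below (no PyTorch) contrasts the defaults with parameters derived from an observed output range, the way a min/max observer would after calibration; calibrating on a single sample is an assumption made for brevity.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Quantize one float (quint8 range by default): clamp(round(x / scale) + zero_point)."""
    return min(max(round(x / scale) + zero_point, qmin), qmax)

def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)

y = 803.690776526948  # float conv output from the post (804 plus the bias)

# qm's defaults, as printed in the post: scale=1.0, zero_point=0.
q_default = quantize(y, 1.0, 0)
print(dequantize(q_default, 1.0, 0))  # 255.0: saturated

# Min/max-style parameters from an observed output range (one sample here).
observed = [y]
lo, hi = min(min(observed), 0.0), max(max(observed), 0.0)
out_scale = (hi - lo) / (255 - 0)
out_zero_point = 0 - round(lo / out_scale)
q_calibrated = quantize(y, out_scale, out_zero_point)
print(dequantize(q_calibrated, out_scale, out_zero_point))  # close to 803.69
```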

Check out the from_float method in pytorch/conv.py (pytorch/pytorch on GitHub) for the logic the production path uses. In your code, you could attach an observer to your conv, populate that observer's stats either via calibration or by manually setting the variables, and then use the from_float method to create your quantized conv.
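A minimal PyTorch sketch of the "manually setting the variables" route, assuming a recent PyTorch where the module lives in torch.ao.nn.quantized (torch.nn.quantized, used in the question, is the older import path). It reuses the tensors printed in the question and quantizes input and weight with scale 1.0 and zero_point 0. That is lossless here because the values are small integers, which isolates the output parameters. It then fills a MinMaxObserver with the float output as a one-sample calibration. The weight goes in through set_weight_bias(): weight() returns an unpacked copy, so assigning to qm.weight().data, as the question does, leaves the module unchanged.

```python
import torch
import torch.nn.functional as F
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import MinMaxObserver

# Tensors printed in the question, plus the float layer's bias.
x = torch.tensor([[[[4., 19., 13.], [0., 3., 19.], [7., 3., 17.]]]])
w = torch.tensor([[[[3., 1., 6.], [16., 19., 18.], [16., 16., 8.]]]])
b = torch.tensor([-0.309223473072052])
ref = F.conv2d(x, w, b)  # float reference, about 803.69

# Small integers are represented exactly with scale 1.0 and zero_point 0.
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)

# "Calibrate" an output observer on the float result (one sample only).
obs = MinMaxObserver(dtype=torch.quint8)
obs(ref)
out_scale, out_zero_point = obs.calculate_qparams()

qm = nnq.Conv2d(1, 1, 3, stride=1)
# weight() returns an unpacked copy, so qm.weight().data = ... would not change qm;
# set_weight_bias() repacks the weight (the bias stays float).
qm.set_weight_bias(qw, b)
qm.scale = float(out_scale)          # output quantization parameters
qm.zero_point = int(out_zero_point)

qy = qm(qx)
print(ref.item(), qy.dequantize().item())
```

from_float roughly automates these steps: it takes the weight observer from the float module's qconfig and the output scale and zero_point from the module's calibrated activation_post_process observer.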

Thank you, @Vasiliy_Kuznetsov, for your quick response. I will look into this. One question on quantized convolution: what does packed_params = torch.ops.quantized.conv2d_prepack return? Can you also point me to information on how convolution is performed for quantized tensors with scale and zero_point attributes?
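As a rough sketch of the arithmetic behind a quantized convolution, here is the standard integer formulation. It is not a transcription of PyTorch's FBGEMM or QNNPACK kernels, which differ in details such as bias handling and vectorization. The 8-bit values are not multiplied as 8-bit results: each output is an int32 accumulation of (x_q - z_x)(w_q - z_w) over the receptive field, then requantized into the output's scale and zero point. The plain-Python version below computes the single output of the 3x3 example. It takes the quantized input exactly as printed in the post and makes three assumptions: symmetric qint8 weights (zero_point 0, scale max|w| / 127), an output scale of 4.0, and the bias added in float for simplicity.

```python
# Plain-Python sketch of one output of a quantized 3x3 convolution.
def quantize(x, scale, zero_point, qmin, qmax):
    return min(max(round(x / scale) + zero_point, qmin), qmax)

# Quantized input exactly as printed in the post (quint8, zero_point 0).
x_q = [53, 254, 174, 0, 40, 254, 94, 40, 227]
x_scale, x_zp = 0.07480315119028091, 0

# Weights from the post, quantized symmetrically to qint8 (assumed scheme).
weight = [3., 1., 6., 16., 19., 18., 16., 16., 8.]
w_scale, w_zp = max(abs(v) for v in weight) / 127, 0
w_q = [quantize(v, w_scale, w_zp, -128, 127) for v in weight]

bias = -0.309223473072052      # float bias from the post
y_scale, y_zp = 4.0, 0         # assumed output parameters (quint8 covers 0..1020)

# 1) Integer multiply-accumulate with zero points removed (fits easily in int32).
acc = sum((xq - x_zp) * (wq - w_zp) for xq, wq in zip(x_q, w_q))

# 2) Requantize: rescale the accumulator by x_scale * w_scale, add the bias,
#    and map the result onto the output's quint8 grid.
real = acc * x_scale * w_scale + bias
y_q = quantize(real, y_scale, y_zp, 0, 255)

print(acc, y_q, y_scale * (y_q - y_zp))  # 71727 201 804.0
```

The dequantized result, 804.0, lies within half an output step of the float output 803.69.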