How does quantized conv2d handle scale and zero_point?

Dear PyTorch community,

For my research, I am reimplementing the convolution function in my own code. However, my manual conv2d result does not match the output of torch.nn.quantized.functional. I would appreciate any advice on how a convolution works on 3-dimensional inputs and kernels with quantized parameters.

As a concrete example, I've been working on the first conv layer of a ResNet50 model. My quantized input images have shape (1, 3, 224, 224). They are padded so that the input to the conv layer is (1, 3, 230, 230). The convolution has 64 kernels of shape (3, 7, 7) and a stride of 2, which gives an output of (64, 112, 112). At every kernel position, I compare my result, manual_res, with the element output_ref at the same location in the output tensor. The indexing shouldn't be the problem: when I run qF.conv2d on the same two tensors (the input window and the kernel), its result matches output_ref. Here's the convolution loop I've written.

```python
import torch
import torch.nn.functional as F
from torch.nn.quantized import functional as qF

# after_quant (quantized input), conv1 (first quantized conv layer), conv1_weight (its
# quantized weight) and after_conv1 (its reference output) come from the model (not shown).
conv1_pad = (3, 3, 3, 3)
after_pad = F.pad(after_quant, conv1_pad, "constant", 0)  # input to the conv layer
print("after pad: ", after_pad.shape)   # 1, 3, 230, 230

my_conv1_result = torch.zeros(after_conv1.shape)

for c in range(0, conv1_weight.shape[0]):  # 64 output channels
    kernel = conv1_weight[c]  # 3x7x7 kernel
    target_y = 0  # index in result tensor

    for start_y in range(0, after_pad.shape[2] - 7, 2):  # 112
        target_x = 0
        # print(start_y, end=", ")
        for start_x in range(0, after_pad.shape[3] - 7, 2):  # 112
            input_tensor = after_pad[0, :, start_x:start_x + 7, start_y:start_y + 7]  # 3x7x7
            manual_res   = torch.tensor(0, dtype=torch.int8)  # uint8 for activation
            output_ref   = after_conv1[0, c, target_x, target_y]

            # print(input_tensor.int_repr())
            # print(kernel.int_repr())
            print("output_ref:", output_ref.int_repr(), end="  =====  ")

            for i in range(kernel.shape[0]):  # 3
                for j in range(kernel.shape[1]):  # 7
                    for k in range(kernel.shape[2]):  # 7
                        #####################
                        # Multiply and accumulate
                        temp = (input_tensor.int_repr()[i, j, k] - input_tensor.q_zero_point()) * (kernel.int_repr()[i, j, k] - kernel.q_zero_point())
                        manual_res = manual_res + temp
                        #####################

            manual_res = conv1.zero_point + (manual_res * (input_tensor.q_scale() * kernel.q_scale() / conv1.scale)).round()
            manual_res = 255 if manual_res > 255 else 0 if manual_res < 0 else manual_res
            print("manual_res:", manual_res, end="  =====  ")
            my_conv1_result[0, c, target_x, target_y] = manual_res

            # conv1 is the first conv layer with its scale and zp
            qf_conv_res = qF.conv2d(input_tensor.reshape((1, 3, 7, 7)), kernel.reshape((1, 3, 7, 7)), bias=torch.tensor([0], dtype=torch.float), scale=conv1.scale, zero_point=conv1.zero_point)
            print("qF.conv2d ref:", qf_conv_res.int_repr())

            target_x += 1

        target_y += 1
```
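The loop bounds can be cross-checked against the output-size formula from the torch.nn.Conv2d documentation. The sketch below (the helper name conv_out_size is hypothetical) confirms that both the unpadded input with padding 3 and the already padded 230x230 input give 112 positions per spatial dimension. Iterating with range(0, H - K + 1, stride) visits every valid window; for these particular sizes the loop's range(0, H - K, stride) happens to yield the same 112 starts.

```python
def conv_out_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
    # Output-size formula given in the torch.nn.Conv2d documentation.
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Unpadded 224 input with padding=3, or the already padded 230 input with padding=0.
print(conv_out_size(224, 7, stride=2, padding=3))  # 112
print(conv_out_size(230, 7, stride=2))             # 112

# Window start positions for a manual loop: range(0, H - K + 1, stride).
starts = range(0, 230 - 7 + 1, 2)
print(len(starts), starts[-1])                     # 112 222
```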

The printed output shows a mismatch between my manual_res and both output_ref and the qF.conv2d reference:

```
output_ref: tensor(66, dtype=torch.uint8)  =====  manual_res: tensor(75.)  =====  qF.conv2d ref: tensor([[[[66]]]], dtype=torch.uint8)
output_ref: tensor(66, dtype=torch.uint8)  =====  manual_res: tensor(79.)  =====  qF.conv2d ref: tensor([[[[66]]]], dtype=torch.uint8)
output_ref: tensor(66, dtype=torch.uint8)  =====  manual_res: tensor(64.)  =====  qF.conv2d ref: tensor([[[[66]]]], dtype=torch.uint8)
```
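One thing the printed values do not show is the integer dtype of each intermediate in the loop. int_repr() returns a uint8 tensor for a quint8 tensor and an int8 tensor for a qint8 tensor, and manual_res starts as an int8 tensor. A minimal check of PyTorch's type-promotion rules, using illustrative values (a raw value of 3 and a zero point of 5, not taken from the model):

```python
import torch

# Raw value of a quint8 element, as returned by int_repr(): dtype uint8.
q_act = torch.tensor(3, dtype=torch.uint8)
zp_act = 5  # q_zero_point() returns a plain Python int

# uint8 minus a Python int stays uint8, so 3 - 5 wraps around instead of giving -2.
wrapped = q_act - zp_act
print(wrapped.dtype, wrapped.item())  # torch.uint8 254

# Casting to int32 first gives the signed difference the formula expects.
widened = q_act.to(torch.int32) - zp_act
print(widened.dtype, widened.item())  # torch.int32 -2

# Mixing uint8 and int8 operands promotes to int16, which can still overflow
# when 3 x 7 x 7 = 147 products are summed.
print(torch.promote_types(torch.uint8, torch.int8))  # torch.int16
```

If a difference wraps like this, or the running sum exceeds the range of its dtype, the result silently departs from the formula; casting the raw values to torch.int32 (or torch.int64) before subtracting the zero points avoids both effects.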

I think the problem lies in how the scales and zero points are handled inside the loop. I am following this paper: https://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf

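It states that

$$
q_3^{(i,k)} = Z_3 + M \sum_{j=1}^{N} \left(q_1^{(i,j)} - Z_1\right)\left(q_2^{(j,k)} - Z_2\right), \qquad M = \frac{S_1 S_2}{S_3}
$$
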
In my case, q3 is manual_res; Z1, Z2, Z3 and S1, S2, S3 are the zero points and scales of the input, the kernel, and conv1, respectively; and q1, q2 are elements of the input and the kernel. I really want to understand why my implementation is not correct. (As a side note, I tried my implementation on random 2-dimensional tensors with the same handling of scales and zero points, and it seemed to work fine.)
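The formula above can be sketched for a single output element in plain Python. This is a minimal sketch under stated assumptions, not the backend's implementation: the function name requantize_window and the bias_q parameter are hypothetical, the inputs are flat lists of raw integers (for example from t.int_repr().flatten().tolist()), and the bias is assumed to be already quantized to int32 with scale S1 * S2 and zero point 0, as the paper describes for biases.

```python
def requantize_window(q_in, q_w, z_in, z_w, z_out, s_in, s_w, s_out,
                      bias_q=0, qmin=0, qmax=255):
    """One quantized output element following the formula above (sketch).

    q_in, q_w: flat lists of raw integer values for one input window and one kernel.
    bias_q: hypothetical int32 bias with scale s_in * s_w and zero point 0.
    """
    if len(q_in) != len(q_w):
        raise ValueError("window and kernel must have the same number of elements")

    # Subtract zero points before multiplying. Python ints never overflow,
    # which mimics the wide (int32) accumulator used in the paper.
    acc = bias_q
    for a, b in zip(q_in, q_w):
        acc += (a - z_in) * (b - z_w)

    # Real-valued multiplier M = S1 * S2 / S3, then round and add Z3.
    m = s_in * s_w / s_out
    q_out = z_out + round(m * acc)

    # Saturate to the output range (0..255 for quint8).
    return max(qmin, min(qmax, q_out))


# Illustrative numbers only (not taken from the ResNet50 model).
print(requantize_window([130, 120, 128], [10, -5, 3],
                        z_in=128, z_w=0, z_out=64,
                        s_in=0.02, s_w=0.01, s_out=0.001))  # 76
```

Here the accumulator is 2*10 + (-8)*(-5) + 0*3 = 60 and M is about 0.2, so the result is 64 + 12 = 76. If the weight is per-channel quantized (torch.per_channel_affine), S2 and Z2 would be the values for output channel c, for example kernel_scales = conv1_weight.q_per_channel_scales() indexed at c. A backend's own requantization (for example the fixed-point multiplier described in the paper) may occasionally differ from this floating-point rounding by one unit.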

Again, thanks for your help!