Quantized Conv2d gives a different result from Caffe2's Int8Conv with the same weights

  1. Create a module with a single Conv2d (with constant weights) and add quant stubs
  2. Save the results of float inference in init_res
  3. Perform post-training static quantization
  4. Save the results of quantized inference in q_res

init_res and q_res differ in almost 100% of elements.

  1. Build a Caffe2 int8 network with a single conv and the weights from the quantized PyTorch module (with the necessary NCHW -> NHWC transpositions)
  2. Save the inference results in caffe_res

init_res and caffe_res match within rtol=0.05.

Question: what am I doing wrong? Why does the quantized PyTorch conv give such different results?

P.S. The same tensor x is used for every inference and calibration run.

Code:

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

print(torch.__version__)
torch.manual_seed(123)

import numpy as np
from caffe2.python import workspace, model_helper

class QuantizedConv(nn.Module):
    def __init__(self):
        super(QuantizedConv, self).__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        # Constant weights, zero bias
        self.conv.weight.data.fill_(0.01)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        return self.dequant(x)
    

model = QuantizedConv()
model.eval()

x = torch.randn(1, 3, 320, 320, dtype=torch.float)

init_res = model(x).detach().numpy()

model.qconfig = torch.quantization.get_default_qconfig("qnnpack")
print(model.qconfig)
torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    for i in range(10):  # calibration runs
        model(x)
torch.quantization.convert(model, inplace=True)

q_res = model(x).detach().numpy()
# np.testing.assert_allclose( q_res, init_res, rtol=0.05, atol=0.04) 
# <- error
# Not equal to tolerance rtol=0.05, atol=0.04
# (mismatch 99.97626676159962%)
#  x: array([-0.159745, -0.194942, -0.138084, ..., -0.138084, -0.17599 ,
#       -0.140792], dtype=float32)
# y: array([-0.049151, -0.018199,  0.065673, ...,  0.056086, -0.027169,
#       -0.023959], dtype=float32)

Caffe2 part:

workspace.ResetWorkspace()
workspace.FeedBlob("x1", x.numpy())  # FeedBlob expects a numpy array, not a torch tensor
net = model_helper.ModelHelper(name="test")

net.net.NCHW2NHWC("x1", "x") # Transpose input tensor
net.net.Int8Quantize("x", 
                     "x_int8", 
                     Y_scale=model.quant.scale.numpy()[0], 
                     Y_zero_point=model.quant.zero_point.numpy()[0],
                     )

W = model.conv.weight().int_repr().detach().numpy()
W = np.transpose(W, [0, 2, 3, 1])  # Transpose conv weights: OIHW -> OHWI for NHWC
net.param_init_net.Int8GivenTensorFill(
    [],
    "conv_w",
    shape=W.shape,  # (32, 5, 5, 3)
    values=W.tobytes(),
    Y_scale=model.conv.weight().q_scale(),
    Y_zero_point=0,
)
net.param_init_net.Int8GivenIntTensorFill(
    [],
    "conv_b",
    shape=[32],  # bias shape
    values=np.zeros(32, dtype=np.int32),  # zero bias, matching the module
    Y_scale=1.,
    Y_zero_point=0,
)

kwargs = {
    "kernel": 5,
    "stride": 2,
    "pad": 1,
    "order": "NHWC",
    "Y_scale": model.conv.scale,
    "Y_zero_point": model.conv.zero_point,
}
net.net.Int8Conv(["x_int8", "conv_w", "conv_b"], "conv_1_q", **kwargs)
net.net.Int8Dequantize("conv_1_q",
                       "res",
                       Y_scale=1.,
                       Y_zero_point=0)
workspace.RunNetOnce(net.param_init_net)
workspace.CreateNet(net.net)
workspace.RunNet(net.name)

caffe_res = workspace.FetchBlob("res")
caffe_res = np.transpose(caffe_res, [0, 3, 1, 2])  # Transpose output back: NHWC -> NCHW

np.testing.assert_allclose( caffe_res, init_res, rtol=0.05, atol=0.1)

I also noticed that the max value of the PyTorch quantized result is 0:

print(f'min {init_res.min()}, max {init_res.max()}')
print(f'min {q_res.min()}, max {q_res.max()}')
print(f'min {caffe_res.min()}, max {caffe_res.max()}')

# min -0.3654041886329651, max 0.36045965552330017
# min -0.36551710963249207, max 0.0
# min -0.36551710963249207, max 0.32490411400794983
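
A quick way to dig into that saturated max (a minimal sketch using the attributes the converted modules expose) is to print the quantization parameters that calibration produced:

# Quantization parameters picked during calibration/convert.
# An output scale/zero_point that clips the positive range would explain max == 0.
print(model.quant.scale, model.quant.zero_point)   # input Quantize params
print(model.conv.scale, model.conv.zero_point)     # conv output requantization params
print(model.conv.weight().q_scale(), model.conv.weight().q_zero_point())  # weight params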

With a different Conv2d weight initialization,

self.conv.weight.data.uniform_(-0.05, 0.05)

the PyTorch quantized results become close enough to init_res, but the Caffe2 conv output is very different.

The issue is that by default PyTorch uses fbgemm as the backend for quantized operators. You can check this with

print(torch.backends.quantized.engine)

Caffe2's Int8Conv uses the qnnpack engine by default. To ensure that PyTorch and Caffe2 match, you will need to do the following:


torch.backends.quantized.engine = 'qnnpack'  # set the engine to qnnpack
model.qconfig = torch.quantization.get_default_qconfig("qnnpack")
print(model.qconfig)
torch.quantization.prepare(model, inplace=True)

Once I do that, the outputs match.

Is it possible to call Int8Conv in Caffe2 with signed weights?

self.conv.weight.data.uniform_(-0.05, 0.05)

It should work; qnnpack can take both signed and unsigned weights in floating point. Internally, the weights are represented as uint8 values with a scale and zero point mapping from float to quantized values.
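
For intuition, here is a minimal sketch of that mapping using PyTorch's quantization primitives (the scale and zero point values below are illustrative, not the ones qnnpack picks):

import torch

# Affine mapping: q = clamp(round(x / scale) + zero_point, 0, 255)
w = torch.empty(4).uniform_(-0.05, 0.05)  # signed float weights
qw = torch.quantize_per_tensor(w, scale=0.1 / 255, zero_point=128, dtype=torch.quint8)

print(qw.int_repr())                    # stored as uint8
print(qw.q_scale(), qw.q_zero_point())  # scale and zero point
print(qw.dequantize())                  # approximately recovers the original floats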

I wasn't able to run Caffe2's Int8Conv with signed weights.
Should I use Int8ConvPackWeight to make it work?

I tried looking at the test code (since I couldn't find any docs apart from the code), and before weight packing they use uint8.
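
Since the affine mapping x = scale * (q - zero_point) is unchanged when both q and the zero point are shifted by the same amount, one workaround (a sketch, untested against this exact setup) is to shift the signed int8 weights by 128 into uint8 and declare a zero point of 128:

import numpy as np

# Sketch: convert PyTorch's signed int8 conv weights (zero_point == 0 for the
# default qnnpack weight observer) to the unsigned uint8 layout expected by
# Caffe2's Int8GivenTensorFill. Shifting values and zero point by 128 leaves
# the dequantized floats unchanged.
w_int8 = model.conv.weight().int_repr().detach().numpy()     # int8 in [-128, 127]
w_uint8 = (w_int8.astype(np.int16) + 128).astype(np.uint8)   # uint8 in [0, 255]
w_uint8 = np.transpose(w_uint8, [0, 2, 3, 1])                # NCHW -> NHWC

net.param_init_net.Int8GivenTensorFill(
    [],
    "conv_w",
    shape=w_uint8.shape,
    values=w_uint8.tobytes(),
    Y_scale=model.conv.weight().q_scale(),
    Y_zero_point=128,  # was 0 in the signed int8 representation
)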