The weights of the int8 model do not match the QAT model

I thought that the parameters of the QAT model should be identical to those of the final int8 model, but I have found that this is not actually the case. I wrote a simple script to verify this:

import copy

import torch
from torch.ao.quantization import QConfigMapping, get_default_qat_qconfig, quantize_fx

loss_fn = torch.nn.L1Loss()

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

if __name__ == "__main__":
    model_fp32 = M()
    optim = torch.optim.Adam(model_fp32.parameters(), lr=1e-2)

    mean_loss = []
    for i in range(200):
        inp_tensor = torch.rand(1024).reshape(1024, 1, 1, 1).float()
        oup_tensor = -inp_tensor * 2
        output = model_fp32(inp_tensor)
        loss = loss_fn(output, oup_tensor)
        loss.backward()
        optim.step()
        optim.zero_grad()
        mean_loss += [loss.item()]

    # use the default QAT qconfig for the qnnpack backend for every module
    qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig("qnnpack"))

    example_inputs = (torch.rand(2, 1, 1, 1),)
    model_fp32_prepared = quantize_fx.prepare_qat_fx(
        copy.deepcopy(model_fp32).train(), qconfig_mapping, example_inputs)

    optim = torch.optim.Adam(model_fp32_prepared.parameters(), lr=1e-4)

    model_fp32_prepared.train()
    mean_loss = []
    for i in range(200):
        inp_tensor = torch.rand(1024).reshape(1024, 1, 1, 1).float()
        oup_tensor = -inp_tensor * 2
        output = model_fp32_prepared(inp_tensor)
        loss = loss_fn(output, oup_tensor)
        loss.backward()
        optim.step()
        optim.zero_grad()
        mean_loss += [loss.item()]
    
    model_fp32_prepared.eval()
    model_fp32_prepared.apply(torch.ao.quantization.disable_observer)
    model_fp32_prepared.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    model_int8 = quantize_fx.convert_fx(copy.deepcopy(model_fp32_prepared))
    print(model_int8.conv.weight().item(),
          model_fp32_prepared.conv.weight_fake_quant(model_fp32_prepared.conv.weight).item())

In my case, the weight in model_int8 is -2.0080461502075195, while the fake-quantized weight from model_fp32_prepared is -2.000202178955078.

I read the source code of the QAT observer and the convert path in quantize_fx.convert_fx, and the two compute the quantized weight slightly differently.

The convert path computes the weight scale in observer.py, in _calculate_qparams:

max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
scale = torch.max(scale, self.eps)

In this case max_val = min_val = -2.000202178955078, so the scale becomes abs(min_val) / 127.5, while the quantized int value is -128.

I’m curious why it doesn’t divide by 128 instead, which would give a more accurate result.
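To make the arithmetic concrete, here is a minimal sketch in plain Python using the values printed above (the variable names are mine, only for illustration):

min_val = -2.000202178955078   # observer min_val (largest magnitude)
w = -2.0032155513763428        # conv.weight in the prepared QAT model
quant_min, quant_max = -128, 127

scale = abs(min_val) / ((quant_max - quant_min) / 2)       # divide by 127.5, as in _calculate_qparams
q = max(quant_min, min(quant_max, round(w / scale)))       # -> -128
print(q * scale)                                           # ~ -2.0080461, the int8 weight

scale_128 = abs(min_val) / 128                             # divide by 128 instead
q = max(quant_min, min(quant_max, round(w / scale_128)))   # -> -128
print(q * scale_128)                                       # -2.000202178955078, matching the fake-quant output printed above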

Hi @Novelfor

I thought that the parameters of the QAT model should be identical to the final int8 model

I don’t think you should expect that to be the case. We simulate int8 with fp32 values, so some small delta is expected.
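For a rough picture of what "simulate int8 with fp32 values" means, here is a minimal sketch using torch.fake_quantize_per_tensor_affine (the scale and zero_point below are made-up values, purely for illustration):

import torch

w = torch.tensor([-2.0032155513763428])
# round and clamp to the int8 grid, but keep the result as an fp32 tensor
w_fq = torch.fake_quantize_per_tensor_affine(w, 0.0157, 0, -128, 127)
print(w_fq)  # snapped to 0.0157 * n for some integer n in [-128, 127]

During QAT training, the weights stay in fp32 and are only passed through this fake quantization in the forward pass.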

In that code you mentioned, I believe we cannot divide by the int value, because then we would not be calculating the gradients correctly, since we need to backprop through this scale calculation.

cc @andrewor in case I’m missing something here.

I mean that the parameters after weight_fake_quant should be identical to those in the int8 model.
The weight after weight_fake_quant should be
round(conv.weight / conv.weight_fake_quant.scale) * conv.weight_fake_quant.scale
In my case:
conv.weight = -2.0032155513763428
weight_fake_quant.scale = 0.007843930274248123
so the value is round(-255.38416091649415) * 0.007843930274248123 = -2.0002022199332714

But in model_int8 the weight is -2.0080461502075195, as shown in the output above.
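The round trip is easy to check with plain arithmetic, using the reported numbers:

w = -2.0032155513763428
scale = 0.007843930274248123
print(round(w / scale) * scale)   # ~ -2.0002022, the weight_fake_quant output
# the converted int8 weight instead dequantizes to ~ -2.0080462 (see above)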

In that code you mentioned, I believe we cannot divide by the int value, because then we would not be calculating the gradients correctly, since we need to backprop through this scale calculation.

First, the code I quoted computes the scale for the int8 model during convert, not for the QAT model, so there is no need to backprop through it.
Second, I mean the value is wrong when it divides by 127.5: when I run inference with the QAT model, the fake-quantized parameter is min_val, but in the int8 model the weight is -128 * scale = -128 * (abs(min_val) / 127.5). I think the computation should be the same for the QAT model and the int8 model.
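If it helps, the integer value and scale of the converted weight can also be inspected directly on the quantized tensor (this assumes the weight is quantized per tensor; q_scale() would raise an error for per-channel quantization):

qw = model_int8.conv.weight()           # a quantized tensor
print(qw.int_repr())                    # the stored int8 value, e.g. -128
print(qw.q_scale(), qw.q_zero_point())  # per-tensor scale and zero_point
print(qw.dequantize())                  # (int_repr - zero_point) * scale, back in fp32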