Accuracy drop after model quantization

Hello
I am trying to quantize the stacked hourglass model used for 2D pose estimation using static post-training quantization in eager mode. However, after quantization the accuracy drops by almost 20 percent. I am trying to debug this by following the PyTorch Numeric Suite tutorial, and it seems to me that there is some problem with the batch normalization layers, but I am not sure about it.

This is the code I use for quantization (I know I am not fusing the layers)

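import torch

def static(model, dataloader):
    # Eval mode keeps the batchnorm running statistics fixed during calibration.
    model.eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    # Insert observers that record activation ranges.
    torch.ao.quantization.prepare(model, inplace=True)
    # Calibration pass; gradients are not needed.
    with torch.no_grad():
        for inputs, labels, masks in dataloader:
            inputs = inputs.to(device, dtype=torch.float32)  # `device` is defined elsewhere
            model(inputs)
    # Swap the observed modules for their quantized counterparts.
    torch.ao.quantization.convert(model, inplace=True)
    return model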

and this is part of the model after quantization

(0): SeBottleneck(
      (bn1): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): QuantizedConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.5101176500320435, zero_point=77)
      (bn2): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.8481943607330322, zero_point=83, padding=(1, 1))
      (bn3): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): QuantizedConv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), scale=0.13069623708724976, zero_point=92)
      (se): SeBlock(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (fc): Sequential(
          (0): QuantizedLinear(in_features=128, out_features=8, scale=0.05465518683195114, zero_point=33, qscheme=torch.per_channel_affine)
          (1): ReLU(inplace=True)
          (2): QuantizedLinear(in_features=8, out_features=128, scale=0.057319797575473785, zero_point=127, qscheme=torch.per_channel_affine)
          (3): Sigmoid()
        )
        (mul): QFunctional(
          scale=0.021492326632142067, zero_point=83
          (activation_post_process): Identity()
        )
      )
      (downsample): Sequential(
        (0): QuantizedConv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), scale=0.025261757895350456, zero_point=71)
      )
      (add): QFunctional(
        scale=0.030607327818870544, zero_point=62
        (activation_post_process): Identity()
      )
    )

When I compare the outputs of the quantized modules with the outputs of the original modules using compare_model_stub(), and evaluate the error with the compute_error() function from the PyTorch Numeric Suite tutorial, I observe a big difference between the batch normalization layers and all the other layers of the network.

These are the numbers for some of the batch normalization layers

hg.0.hg.2.1.3.bn2.stats tensor(1.2797)
hg.0.hg.2.1.3.bn3.stats tensor(0.1870)
hg.0.hg.2.2.0.bn1.stats tensor(3.4588)
hg.0.hg.2.2.0.bn2.stats tensor(1.0052)
hg.0.hg.2.2.0.bn3.stats tensor(0.8627)
hg.0.hg.2.2.1.bn1.stats tensor(3.1735)
hg.0.hg.2.2.1.bn2.stats tensor(0.8259)
hg.0.hg.2.2.1.bn3.stats tensor(1.1860)
hg.0.hg.2.2.2.bn1.stats tensor(3.0059)
hg.0.hg.2.2.2.bn2.stats tensor(0.7214)

and these are the numbers for some of the convolutional layers

hg.1.hg.3.1.1.conv2.stats tensor(33.3547)
hg.1.hg.3.1.1.conv3.stats tensor(28.7391)
hg.1.hg.3.1.2.conv1.stats tensor(31.8503)
hg.1.hg.3.1.2.conv2.stats tensor(32.0367)
hg.1.hg.3.1.2.conv3.stats tensor(30.5304)
hg.1.hg.3.1.3.conv1.stats tensor(32.4438)
hg.1.hg.3.1.3.conv2.stats tensor(34.6818)
hg.1.hg.3.1.3.conv3.stats tensor(30.7412)
hg.1.hg.3.2.0.conv1.stats tensor(32.4388)
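
For reading these numbers, it may help to recall what the error measures. Assuming compute_error() is the one from the Numeric Suite tutorial, it returns a signal-to-quantization-noise ratio (SQNR) in dB, i.e. 20*log10(||x|| / ||x - y||), so higher is better. The following is a plain-Python sketch of that formula, not the tutorial's code; the name `sqnr_db` and the sample values are made up for illustration.

```python
import math

def sqnr_db(reference, quantized):
    """SQNR in dB: 20 * log10(norm(reference) / norm(reference - quantized))."""
    signal = math.sqrt(sum(r * r for r in reference))
    noise = math.sqrt(sum((r - q) ** 2 for r, q in zip(reference, quantized)))
    return 20 * math.log10(signal / noise)

# Illustrative values only (not taken from the model above).
reference = [1.0, -2.0, 3.0, -4.0]
small_error = [r * 1.03 for r in reference]  # error norm is 3% of the signal norm
large_error = [r * 1.9 for r in reference]   # error norm is 90% of the signal norm

print(round(sqnr_db(reference, small_error), 1))  # 30.5
print(round(sqnr_db(reference, large_error), 1))  # 0.9
```

Under that reading, values around 30 mean the quantized output differs from the float output by a few percent, while values near 1 mean the difference is almost as large as the output itself.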

Is there a problem with batch normalization, or is it normal for the numbers to be so low? What could be the problem?

Thank you

This is generally why we do fusion with these ops: it’s significantly faster and more accurate.

Not sure why it’s not working; the numbers shouldn’t be that low, though. My guess would be that one of the two versions is continuing to collect statistics and update the mean/std, which makes the results of your comparison diverge. Hard to tell without a full repro.

thank you,
do you have any suggestions on how to debug it?
also, what do you mean by “one of the two versions is continuing to collect statistics”?

batchnorm is basically just doing a constant affine transform on the data i.e.

out=(activation-(some constant1))/(some constant2)

where the two constants are chosen by calculating a running average of the mean/stddev of the activation over the training data. Normally during eval you want to freeze those constants so they don’t change over time.
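
As a small worked example of that affine form: in eval mode, PyTorch batchnorm computes (x - running_mean) / sqrt(running_var + eps) * weight + bias, which also includes the learnable weight and bias left out of the simplified formula above. The plain-Python sketch below folds that into a single scale and shift per channel. The function names and the numbers are made up for illustration.

```python
import math

def batchnorm_eval(x, running_mean, running_var, weight=1.0, bias=0.0, eps=1e-5):
    """Eval-mode batchnorm for one scalar activation of one channel."""
    return (x - running_mean) / math.sqrt(running_var + eps) * weight + bias

def fold_to_affine(running_mean, running_var, weight=1.0, bias=0.0, eps=1e-5):
    """Collapse the frozen statistics into out = scale * x + shift."""
    scale = weight / math.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift

# Illustrative frozen statistics for a single channel.
scale, shift = fold_to_affine(running_mean=0.5, running_var=4.0, weight=2.0, bias=0.1)

for x in (-1.0, 0.0, 3.0):
    # Both forms agree as long as the statistics stay frozen.
    assert math.isclose(batchnorm_eval(x, 0.5, 4.0, 2.0, 0.1), scale * x + shift)
```

If the running statistics keep changing, the scale and shift change with them, and the float output drifts away from whatever the quantized module captured at conversion time.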

My first theory is that you aren’t properly freezing the constants so that during eval, these constants change significantly while for the quantized model, they are hard coded to not change.
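
One way to check this theory is to list any batchnorm modules that are still in training mode in the float model used for the comparison. The sketch below uses a small stand-in model and a hypothetical helper, `batchnorms_in_train_mode`; neither comes from the original post. It also shows that torch.no_grad() alone does not stop the running statistics from updating, only eval mode does.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def batchnorms_in_train_mode(model):
    """Names of batchnorm modules that would still update their running stats."""
    return [name for name, m in model.named_modules()
            if isinstance(m, BN_TYPES) and m.training and m.track_running_stats]

# Stand-in model for illustration only.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
x = torch.randn(4, 3, 16, 16)

# Train mode: a forward pass updates running_mean, even under no_grad.
model.train()
before = model[1].running_mean.clone()
with torch.no_grad():
    model(x)
changed_in_train = not torch.equal(before, model[1].running_mean)

# Eval mode: the statistics are frozen.
model.eval()
before = model[1].running_mean.clone()
with torch.no_grad():
    model(x)
changed_in_eval = not torch.equal(before, model[1].running_mean)

print(changed_in_train, changed_in_eval)
print(batchnorms_in_train_mode(model))  # empty when every batchnorm is frozen
```

If the helper returns any names for the float reference model, calling .eval() on that model before the comparison would be the first thing to try.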

If it’s not that…

If I were to debug it, I would first start with the test: https://github.com/pytorch/pytorch/blob/main/test/quantization/core/test_quantized_module.py#L968 which verifies correctness for quantized batchnorm, and try to reproduce it for individual modules from your example. If that doesn’t work, I’d go back to basics and use a toy model that consists of a single batchnorm, calibrate on 2 random inputs, apply quantization, and see what the SQNR is. It should be 30+. If that’s not the case, I’d then try to figure out what running mean/stddev values you have before and after quantization, either through direct inspection or by reverse engineering them from the output.
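
The toy experiment described above could look roughly like the following. This is a minimal sketch, not the linked test: the `ToyBN` class, the input shapes, and the hand-set running statistics are all assumptions made for illustration, and it relies on eager-mode convert swapping nn.BatchNorm2d for its quantized counterpart, as in the model printout earlier in the thread.

```python
import copy
import torch
import torch.nn as nn
import torch.ao.quantization as tq

class ToyBN(nn.Module):
    """Hypothetical toy model: quantize -> single batchnorm -> dequantize."""
    def __init__(self, channels=4):
        super().__init__()
        self.quant = tq.QuantStub()
        self.bn = nn.BatchNorm2d(channels)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.bn(self.quant(x)))

def sqnr_db(x, y):
    # Same formula as the tutorial's compute_error(), assuming it is the SQNR in dB.
    return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))

torch.manual_seed(0)
float_model = ToyBN().eval()
with torch.no_grad():
    # Arbitrary non-trivial statistics so the batchnorm is not an identity.
    float_model.bn.running_mean.fill_(0.5)
    float_model.bn.running_var.fill_(4.0)

# Quantize a copy so the float model stays available as the reference.
qmodel = copy.deepcopy(float_model)
qmodel.qconfig = tq.get_default_qconfig('fbgemm')
tq.prepare(qmodel, inplace=True)
with torch.no_grad():
    for _ in range(2):  # calibrate on 2 random inputs
        qmodel(torch.randn(8, 4, 16, 16))
tq.convert(qmodel, inplace=True)

x = torch.randn(8, 4, 16, 16)
with torch.no_grad():
    toy_sqnr = sqnr_db(float_model(x), qmodel(x)).item()
print(f"SQNR: {toy_sqnr:.1f} dB")

# Direct inspection: the statistics should be identical before and after conversion.
print(torch.equal(float_model.bn.running_mean, qmodel.bn.running_mean))
print(torch.equal(float_model.bn.running_var, qmodel.bn.running_var))
```

If the toy model gives a healthy SQNR while the real batchnorm layers do not, the same comparison of running_mean and running_var can be repeated on the corresponding modules of the real float and quantized models.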