QNNPACK accuracy very poor on UNet model

I am using a UNet model for semantic segmentation. I pass a batch of images to the model, and it is expected to output 0 or 1 for each pixel, depending on whether the pixel belongs to a person or not: 0 is background and 1 is foreground.

I am trying to quantize the UNet model with the PyTorch quantization APIs (static quantization). The accuracy is good with the FBGEMM config. However, with QNNPACK the model classifies every pixel of every image as black (background): the output is very high positive for the background channel and very high negative for the foreground channel. (I apply softmax to get the final classification of each pixel.) The code works fine for FBGEMM, yet the same code produces very different results for QNNPACK. Is there anything missing for QNNPACK? I am pasting my code below for reference.

# Static Quantization - FBGEMM/QNNPACK

import torch.quantization as Q
import torch

framework = 'qnnpack'   # 'qnnpack' for ARM/mobile; change to 'fbgemm' for x86
per_channel_quantized_model = model
per_channel_quantized_model.eval()
per_channel_quantized_model.fuse_model()
per_channel_quantized_model.qconfig = Q.get_default_qconfig(framework)
torch.backends.quantized.engine = framework

print("Preparing . . .")
Q.prepare(per_channel_quantized_model, inplace=True)    # insert observers

print("Running model . . .")
eval_model_for_quantization(per_channel_quantized_model, 'cpu')    # calibration pass

print("Converting . . .")
Q.convert(per_channel_quantized_model, inplace=True)    # swap in quantized modules

print("***************")
print(per_channel_quantized_model)

print("***************")
print("Checking Accuracy . . .")
accuracy = eval_model_for_quantization(per_channel_quantized_model, 'cpu')
print()
print('Evaluation accuracy after quantization', accuracy)

One issue I did encounter is that the upsampling layers of UNet use nn.ConvTranspose2d, which is not supported for quantization. Hence, before this layer we need to dequantize the tensors, apply nn.ConvTranspose2d in float, and then requantize for the subsequent layers. Can this be the reason for the lower accuracy?
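
For reference, here is a minimal sketch of the dequantize/requantize wrapper I mean (module and argument names are illustrative, not the exact ones from the repo):

import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

class UpBlock(nn.Module):
    # illustrative wrapper: dequantize -> float ConvTranspose2d -> requantize
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.dequant = DeQuantStub()   # leave the quantized domain before the unsupported op
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.quant = QuantStub()       # re-enter the quantized domain for subsequent layers

    def forward(self, x):
        x = self.dequant(x)
        x = self.up(x)
        return self.quant(x)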

NOTE - I did try QAT with QNNPACK. However, the model output does not change, i.e. it still gives all black pixels.

Here are some more details…

Following is the model output for some inputs. The value in the left bracket is the target; the values on the right are the outputs for the 2 channels (background and foreground). The right-hand values are passed through a softmax to obtain the final result, which is then compared with the target.
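
For clarity, the post-processing is roughly the following (a minimal sketch; the tensor shape is an assumption):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 2, 224, 224)   # stand-in for the model output: (N, 2, H, W)
probs = F.softmax(logits, dim=1)       # per-pixel probabilities over background/foreground
pred = probs.argmax(dim=1)             # 0 = background (black), 1 = foreground (person)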

Model output before quantization:

 ..........(1.0) (1.42 16.16)
 ..........(1.0) (-40.55 42.14)
 ..........(0.0) (15.20 -19.15)
 ..........(1.0) (-21.16 25.58)
 ..........(1.0) (-43.54 41.77)
 ..........(0.0) (19.74 -23.29)
 ..........(1.0) (-29.66 33.56)
 ..........(1.0) (1.23 -7.96)
 ..........(1.0) (-35.54 42.13)
 ..........(0.0) (16.74 -19.38)
 ..........(0.0) (9.40 -2.54)
 ..........(0.0) (21.67 -27.59)
 ..........(1.0) (-52.96 53.53)
 ..........(0.0) (18.02 -20.90)
 ..........(1.0) (-19.79 22.51)
 ..........(0.0) (13.33 -20.11)
 ..........(0.0) (29.95 -31.26)
 ..........(0.0) (23.35 -29.38)
 ..........(1.0) (-15.23 9.97)
 ..........(0.0) (18.14 -24.80)
 ..........(0.0) (19.13 -26.98)
 ..........(1.0) (-18.12 22.96)

Model output after FBGEMM quantization - as you can see below, the output values did change, but only to a small extent:

..........(1.0) (0.00 18.41)
 ..........(1.0) (-45.41 50.32)
 ..........(0.0) (13.50 -14.73)
 ..........(1.0) (-24.55 30.69)
 ..........(1.0) (-39.28 38.05)
 ..........(0.0) (13.50 -17.18)
 ..........(1.0) (-22.09 25.78)
 ..........(1.0) (2.45 -7.36)
 ..........(1.0) (-23.32 29.46)
 ..........(0.0) (17.18 -23.32)
 ..........(0.0) (12.27 -6.14)
 ..........(0.0) (20.87 -23.32)
 ..........(1.0) (-45.41 49.10)
 ..........(0.0) (15.96 -18.41)
 ..........(1.0) (-17.18 20.87)
 ..........(0.0) (11.05 -18.41)
 ..........(0.0) (27.00 -27.00)
 ..........(0.0) (17.18 -23.32)
 ..........(1.0) (-2.45 1.23)
 ..........(0.0) (15.96 -20.87)
 ..........(0.0) (18.41 -24.55)
 ..........(1.0) (-15.96 20.87)

Now look at the following model output with QNNPACK quantization. The output is very different from the unquantized version. In particular, the value for every pixel is positive for the background channel and negative for the foreground channel.

..........(1.0) (14.06 -17.12)
 ..........(1.0) (11.61 -16.51)
 ..........(0.0) (20.17 -25.06)
 ..........(1.0) (18.34 -22.01)
 ..........(1.0) (15.89 -14.06)
 ..........(0.0) (20.17 -25.67)
 ..........(1.0) (22.62 -29.34)
 ..........(1.0) (24.45 -28.73)
 ..........(1.0) (14.06 -20.17)
 ..........(0.0) (22.62 -28.12)
 ..........(0.0) (27.51 -20.17)
 ..........(0.0) (21.40 -23.84)
 ..........(1.0) (20.78 -29.34)
 ..........(0.0) (17.12 -23.23)
 ..........(1.0) (28.73 -31.18)
 ..........(0.0) (18.34 -20.78)
 ..........(0.0) (20.17 -23.84)
 ..........(0.0) (20.78 -23.84)
 ..........(1.0) (13.45 -16.51)
 ..........(0.0) (17.12 -21.40)
 ..........(0.0) (21.40 -25.67)
 ..........(1.0) (21.40 -26.90)

Any thoughts by anyone?

Did you set the qengine to qnnpack before evaluating the model? You can set the qengine as in https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_quantized.py#L110
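
For example (a minimal snippet; supported_engines just lets you confirm that qnnpack is available in your build):

import torch

print(torch.backends.quantized.supported_engines)   # check that 'qnnpack' is listed for this build
torch.backends.quantized.engine = 'qnnpack'          # set before running the converted model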

Hi Jerry - thanks for the reply.

Yes, I have done this already. You can see the following code line in my first post…

Is the code okay? Am I doing anything wrong?

I see. The quantization code looks correct. cc @supriyar @dskhudia - could you take a look?

Update - Earlier, when I worked on QAT, I was using the wrong config. After correcting it, QAT does help improve the accuracy. However, I am still interested in knowing why, with static quantization and the qnnpack config, the output value for every pixel is positive for the background channel and negative for the foreground channel.
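
For anyone following along, the corrected QAT flow looks roughly like this (a sketch only; the fine-tuning loop is omitted and model is the same variable as in my first post):

import torch
import torch.quantization as Q

model.fuse_model()
model.qconfig = Q.get_default_qat_qconfig('qnnpack')   # QAT qconfig instead of the static one
torch.backends.quantized.engine = 'qnnpack'

model.train()
Q.prepare_qat(model, inplace=True)   # insert fake-quant modules
# ... fine-tune the model for a few epochs here ...
model.eval()
Q.convert(model, inplace=True)       # produce the final quantized model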

The model I am using is available here:
source - https://github.com/thuyngch/Human-Segmentation-PyTorch
model file - https://drive.google.com/file/d/17GZLCi_FHhWo4E4wPobbLAQdBZrlqVnF/view

Some models are more sensitive to quantization than others. For post-training static quantization you can try selectively quantizing only part of the model, e.g. keep the first conv unquantized, to mitigate the problem (see the sketch below). Typically QAT will help in these cases as well.
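
For example, in eager mode you can skip a submodule by clearing its qconfig before prepare (the submodule name below is just a placeholder for whatever the first conv block is called in your model):

import torch.quantization as Q

model.qconfig = Q.get_default_qconfig('qnnpack')
model.first_conv.qconfig = None       # hypothetical name; modules whose qconfig is None stay in float
Q.prepare(model, inplace=True)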

We also have the eager mode Numeric Suite to help debug which layers are more sensitive to quantization: https://pytorch.org/tutorials/prototype/numeric_suite_tutorial.html?highlight=transformer
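
A rough sketch of using it to compare the float and quantized weights layer by layer (the two model variable names here are assumptions):

import torch
import torch.quantization._numeric_suite as ns

# float_model: the original model, quantized_model: the converted one (names assumed)
wt_compare_dict = ns.compare_weights(
    float_model.state_dict(), quantized_model.state_dict())

for key in wt_compare_dict:
    float_w = wt_compare_dict[key]['float']
    quant_w = wt_compare_dict[key]['quantized'].dequantize()
    # a large relative error here points at a quantization-sensitive layer
    print(key, (torch.norm(float_w - quant_w) / torch.norm(float_w)).item())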

We are landing https://github.com/pytorch/pytorch/pull/46077, which fixes a bug where some qnnpack activations had incorrect numerical values if the preceding layer was in NHWC. If your model uses one of the activations patched by that PR, it might help.

Hi @amitdedhia,

I am also trying to use a quantized UNet model (static quantization) for semantic segmentation (10 classes). I am facing issues with the fbgemm configuration itself, as mentioned in the forum post. If possible, can you point out where I'm going wrong in my quantization? Meanwhile, I'm trying to debug using the link shared by @Vasiliy_Kuznetsov.