Convert FP32 model in torchvision.models to INT8 model

Hello

I’d like to convert an fp32 model from torchvision.models to an INT8 model to accelerate CPU inference.
As I understand it, the model can be converted using prepare() and convert() (https://pytorch.org/docs/stable/quantization.html#id1).

Any ideas on how to handle the error messages below?
Thanks a lot!

Environment:

Nvidia Jetson TX2
pytorch 1.4.0

My code:

import time
import numpy as np
import torch
import torchvision
from torch.quantization import QuantStub

model = torchvision.models.vgg16(pretrained=True).eval()
img = np.random.randint(255, size=(1,3,224,224), dtype=np.uint8)
img = torch.FloatTensor(img)#.cuda()


model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
config = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.backends.quantized.engine = 'qnnpack'

model.qconfig = torch.quantization.default_qconfig
model = torch.quantization.prepare(model)
model = torch.quantization.convert(model)

model.eval()
quant = QuantStub()
img = quant(img)
for loop in range(100):
    start = time.time()
    output = model.forward(img)#, layer[1])
    _, predicted = torch.max(output, 1)
    end = time.time()

    print(end-start)

Result:

/home/user/.local/lib/python3.6/site-packages/torch/quantization/observer.py:172: UserWarning: Must run observer before calling calculate_qparams. Returning default scale and zero point.
Traceback (most recent call last):
  File "quan2.py", line 37, in <module>
    output = model.forward(img)#, layer[1])
  File "/usr/local/lib/python3.6/dist-packages/torchvision-0.5.0a0+85b8fbf-py3.6-linux-aarch64.egg/torchvision/models/vgg.py", line 43, in forward
    x = self.features(x)
  File "/home/user/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user/.local/lib/python3.6/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/home/user/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user/.local/lib/python3.6/site-packages/torch/nn/quantized/modules/conv.py", line 215, in forward
    self.dilation, self.groups, self.scale, self.zero_point)
RuntimeError: Could not run 'quantized::conv2d' with arguments from the 'CPUTensorId' backend. 'quantized::conv2d' is only available for these backends: [QuantizedCPUTensorId]. (dispatch_ at /home/user/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:257)

There are a couple of things that might be related to the error:

  1. You are using the original vgg16; the model needs some modifications before it can be quantized. Take a look at torchvision/models/quantization/resnet.py as well as the tutorial https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html to see how to modify the model.
  2. QuantStub() needs to be added to the model rather than applied to the input directly.
  3. Between prepare() and convert() you need to run the model so the observers can collect histograms of the activations, which are used to determine the scale and zero point of the quantized model (see the sketch after this list).
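For reference, here is a minimal sketch of that pattern. The wrapper class name QuantizableVGG and the calibration loop are illustrative only; depending on your torch/torchvision version, some inner ops of VGG may still need the adjustments shown in torchvision/models/quantization:

import torch
import torchvision
from torch.quantization import QuantStub, DeQuantStub

class QuantizableVGG(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # marks where fp32 input becomes int8
        self.model = torchvision.models.vgg16(pretrained=True)
        self.dequant = DeQuantStub()  # marks where int8 output becomes fp32

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        return self.dequant(x)

model = QuantizableVGG().eval()
torch.backends.quantized.engine = 'qnnpack'
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')

model = torch.quantization.prepare(model)

# Calibration: feed representative fp32 batches so the observers can
# record activation ranges (random data here is for illustration only).
with torch.no_grad():
    for _ in range(10):
        model(torch.randn(1, 3, 224, 224))

model = torch.quantization.convert(model)
output = model(torch.randn(1, 3, 224, 224))  # now runs quantized CPU kernels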

Also, are you running it on GPU? PyTorch quantization currently only supports the CPU and mobile backends.

Yes, I am only considering CPU operation. Thanks.

Thanks a lot. By the way, in torchvision/models/quantization/resnet.py the specified backend is ‘fbgemm’, while in torchvision/models/quantization/mobilenet.py it is ‘qnnpack’. Because I am working on an ARM processor, I have to use the ‘qnnpack’ backend. Is there any difference in implementation methodology between ‘fbgemm’ and ‘qnnpack’?

The quantization workflow and methodology are the same for ‘fbgemm’ and ‘qnnpack’; they are just different backends for different platforms, as the snippet below shows.
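In code, the only backend-specific choice is which engine and qconfig you select (a minimal sketch; model stands for your float model):

backend = 'qnnpack'  # ARM (e.g. Jetson TX2, mobile); use 'fbgemm' on x86 servers
torch.backends.quantized.engine = backend
model.qconfig = torch.quantization.get_default_qconfig(backend)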

  1. Although I carefully read the linked tutorial, it is hard to find the ‘observer’ phase. Could you explain the observer phase? Which function in the tutorial performs it?

  2. Even if there is no batch norm layer, is fusing always necessary?

  1. The observer phase is called calibration in the tutorial. It runs several iterations and uses observers to collect statistics of the activations and weights, which are then used to compute the quantization parameters in convert():
# Calibrate with the training set
evaluate(myModel, criterion, data_loader, neval_batches=num_calibration_batches)
  2. You need to fuse batch norm; see this post: Static quantizing and batch norm error (could not run aten::native_batch_norm with args from QuantizedCPUTensorId backend).

Other than that, fusing is not strictly necessary, but it helps both performance and accuracy.
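For example, here is a minimal fusion sketch for the first Conv2d + ReLU pairs of torchvision’s vgg16 (the ‘features.N’ names are illustrative and should be adapted to your model; fusion must be done on the float model, in eval() mode, before prepare()):

model_fp32 = torchvision.models.vgg16(pretrained=True).eval()
torch.quantization.fuse_modules(
    model_fp32,
    [['features.0', 'features.1'],   # Conv2d + ReLU
     ['features.2', 'features.3']],  # Conv2d + ReLU
    inplace=True)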

I see.

To summarize what I understood, the quantization steps are as follows.

  1. Load the pretrained fp32 model.
  2. Run prepare() to attach observers to the fp32 model in preparation for conversion.
  3. Run fp32model.forward() on calibration data a sufficient number of times to calibrate the model. However, this calibration phase feels like a ‘black box’ process, so I cannot tell whether the calibration has actually happened (see the sketch below for one way to check).
  4. Run convert() to finally turn the calibrated model into a usable int8 model.
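Regarding step 3: one way to see that calibration actually ran is to inspect the observers that prepare() attaches, before calling convert(). A rough sketch (prepared_model stands for the output of prepare(); the module naming follows the default observers and may differ across versions):

# After feeding calibration data through the prepared (not yet converted)
# model, the attached activation observers should report non-default
# quantization parameters instead of warning about missing statistics.
for name, module in prepared_model.named_modules():
    if 'activation_post_process' in name and hasattr(module, 'calculate_qparams'):
        scale, zero_point = module.calculate_qparams()
        print(name, 'scale:', scale, 'zero_point:', zero_point)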