Quantizing only a part of the model

Hi, everyone!
I have a problem when quantizing only a part of the model, as shown below. I had no errors when I quantized the whole model.

import torch

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(1, 1, 1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # conv + relu run in fp32
        x = self.conv(x)
        x = self.relu(x)
        # only conv2 is wrapped between the quant/dequant stubs
        x = self.quant(x)
        x = self.conv2(x)
        x = self.dequant(x)
        x = self.relu(x)
        return x

model_fp32 = M()
model_fp32.eval()

# only conv2 is given a qconfig, so only conv2 should be quantized
model_fp32.conv2.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare(model_fp32)

input_fp32 = torch.randn(40, 1, 16, 16)
model_fp32_prepared(input_fp32) # calibration

model_int8 = torch.quantization.convert(model_fp32_prepared)
result = model_int8(input_fp32)

The code gives a runtime error at the last inference step.

/usr/local/lib/python3.6/dist-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
Traceback (most recent call last):
  File "/notebooks/PycharmProjects/NetworkCompression/pytorch_quantization_prac/main_2.py", line 33, in <module>
    result = model_int8(input_fp32)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/notebooks/PycharmProjects/NetworkCompression/pytorch_quantization_prac/main_2.py", line 17, in forward
    x = self.conv2(x)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/quantized/modules/conv.py", line 332, in forward
    input, self._packed_params, self.scale, self.zero_point)
RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, Tracer, Autocast, Batched, VmapMode].

QuantizedCPU: registered at /pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp:858 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradXLA: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
Tracer: fallthrough registered at /pytorch/torch/csrc/jit/frontend/tracer.cpp:967 [backend fallback]
Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:511 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]


Process finished with exit code 1

Hi @fred107, it is not working because model_fp32.quant does not have a qconfig specified. The quantization convert API only swaps modules that have a qconfig defined. You can fix this by adding something like model_fp32.quant.qconfig = torch.quantization.get_default_qconfig('fbgemm') before calling prepare.
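
For reference, here is a minimal sketch of the suggested fix, reusing the same M class and calibration input from the original post (exact behavior may vary slightly between PyTorch versions):

# sketch of the fix: give both the QuantStub and conv2 a qconfig
model_fp32 = M()
model_fp32.eval()

model_fp32.quant.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_fp32.conv2.qconfig = torch.quantization.get_default_qconfig('fbgemm')

model_fp32_prepared = torch.quantization.prepare(model_fp32)

input_fp32 = torch.randn(40, 1, 16, 16)
model_fp32_prepared(input_fp32)  # calibration

model_int8 = torch.quantization.convert(model_fp32_prepared)
result = model_int8(input_fp32)  # conv2 now runs on the QuantizedCPU backend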


Thank you @Vasiliy_Kuznetsov, it works now!
Then what about model_fp32.dequant.qconfig? It works even when model_fp32.dequant.qconfig is not defined.

DeQuantStub() does nothing but convert the tensor from quint8 back to float32. It does not store any quantization parameters such as the scale factor or zero point.

That is why, when quantizing multiple blocks, we need a separate instance of QuantStub() for each block, but a single instance of DeQuantStub() can dequantize them all.
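
As an illustration, here is a minimal sketch of that pattern with a hypothetical two-block model (TwoBlocks, its layers, and the input shape are made up for this example): each quantized region gets its own QuantStub, while a single shared DeQuantStub converts both regions back to float.

import torch

class TwoBlocks(torch.nn.Module):
    def __init__(self):
        super(TwoBlocks, self).__init__()
        # one QuantStub per quantized region; each records its own scale/zero_point
        self.quant1 = torch.quantization.QuantStub()
        self.quant2 = torch.quantization.QuantStub()
        self.conv1 = torch.nn.Conv2d(1, 1, 1)
        self.conv2 = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # one DeQuantStub is enough because it holds no quantization parameters
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant1(x)
        x = self.conv1(x)
        x = self.dequant(x)   # back to float between the two quantized blocks
        x = self.relu(x)
        x = self.quant2(x)
        x = self.conv2(x)
        x = self.dequant(x)   # the same stub dequantizes the second block too
        return x

model = TwoBlocks()
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # applies to all submodules
model_prepared = torch.quantization.prepare(model)
model_prepared(torch.randn(8, 1, 16, 16))  # calibration
model_int8 = torch.quantization.convert(model_prepared)
out = model_int8(torch.randn(8, 1, 16, 16))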


cc @jerryzh168, is it expected that a DeQuantStub is swapped without a qconfig defined in Eager mode?

Previously that was true, but it was fixed by [quant][eagermode][fix] Fix quantization for DeQuantStub by jerryzh168 · Pull Request #49428 · pytorch/pytorch · GitHub