Quantization error on pre-trained model

I want to quantize my model to load on a Raspberry Pi for fast inference. I am using a pre-trained model and am getting the error below. How can I go about resolving this issue?
# Static quantization of a torch model
import torch

# `device` is assumed to be defined earlier in the notebook (e.g. torch.device("cpu"))
model_tp = torch.hub.load('yangsenius/TransPose:main',
                          'tph_a4_256x192',
                          pretrained=True)

model_tp.final_layer = torch.nn.Sequential(torch.nn.Conv2d(96, 16, kernel_size=1))
model = model_tp.to(device)

x = torch.randn(1, 3, 256, 192).to(device, dtype=torch.float32)
print(f"Output shape from model is: {model(x).shape}")

quant_model = model

backend = "qnnpack"

quant_model.qconfig = torch.quantization.get_default_qconfig(backend)

torch.backends.quantized.engine = backend

model_static_quantized = torch.quantization.prepare(quant_model, inplace=False)

model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)

NotImplementedError Traceback (most recent call last)

in ()
9 image = inp_img.squeeze(0).permute(1,2,0).cpu().detach().numpy()
10
----> 11 output = model_static_quantized(inp_img)
12 output = np.expand_dims(output[0].cpu().detach().numpy(),0)
13

3 frames

/usr/local/lib/python3.7/dist-packages/torch/nn/quantized/modules/conv.py in forward(self, input)
424 mode=self.padding_mode)
425 return ops.quantized.conv2d(
----> 426 input, self._packed_params, self.scale, self.zero_point)
427
428 @classmethod

NotImplementedError: Could not run ‘quantized::conv2d.new’ with arguments from the ‘CPU’ backend. This could be because the operator doesn’t exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit Internal Login for possible resolutions. ‘quantized::conv2d.new’ is only available for these backends: [QuantizedCPU, BackendSelect, Python, Named, Conjugate, Negative, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, Tracer, UNKNOWN_TENSOR_TYPE_ID, Autocast, Batched, VmapMode].

QuantizedCPU: registered at …/aten/src/ATen/native/quantized/cpu/qconv.cpp:883 [kernel]
BackendSelect: fallthrough registered at …/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at …/aten/src/ATen/core/PythonFallbackKernel.cpp:47 [backend fallback]
Named: registered at …/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at …/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at …/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ADInplaceOrView: fallthrough registered at …/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
AutogradOther: fallthrough registered at …/aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at …/aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at …/aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
AutogradXLA: fallthrough registered at …/aten/src/ATen/core/VariableFallbackKernel.cpp:51 [backend fallback]
AutogradLazy: fallthrough registered at …/aten/src/ATen/core/VariableFallbackKernel.cpp:55 [backend fallback]
AutogradXPU: fallthrough registered at …/aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradMLC: fallthrough registered at …/aten/src/ATen/core/VariableFallbackKernel.cpp:59 [backend fallback]
Tracer: registered at …/torch/csrc/autograd/TraceTypeManual.cpp:291 [backend fallback]
UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at …/aten/src/ATen/autocast_mode.cpp:466 [backend fallback]
Autocast: fallthrough registered at …/aten/src/ATen/autocast_mode.cpp:305 [backend fallback]
Batched: registered at …/aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at …/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]


My guess is that your model doesn't have quant and dequant stubs, so you are passing an fp32 input to a quantized op that is expecting a quantized input.

See the example here: Quantization — PyTorch 1.12 documentation
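The pattern from the linked docs looks roughly like this (a minimal sketch of the idea, not the TransPose model itself):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()      # fp32 -> quint8 boundary
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()  # quint8 -> fp32 boundary

    def forward(self, x):
        x = self.quant(x)        # quantize the incoming fp32 tensor
        x = self.relu(self.conv(x))
        return self.dequant(x)   # dequantize before returning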

After quantizing it looks like this. I don't have the code for the model because I loaded it from torch.hub(). So how do I add quant and dequant stubs in this case?

Short snippet only

TransPoseH(
  (conv1): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
  (bn1): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
  (bn2): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): QuantizedConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0, bias=False)
      (bn1): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): QuantizedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0, bias=False)
      (bn3): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): QuantizedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0, bias=False)
        (1): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
)
model_tp = torch.hub.load('yangsenius/TransPose:main',
                          'tph_a4_256x192',
                          pretrained=True)
      
model_tp.final_layer = torch.nn.Sequential(torch.nn.Conv2d(96, 16, kernel_size=1))
model = model_tp.to(device)

x = torch.randn(1, 3, 256, 192).to(device, dtype=torch.float32)
print(f"Output shape from model is: {model(x).shape}")

# quant_model = model #change this to:
quant_model = torch.nn.Sequential(torch.ao.quantization.QuantStub(), model, torch.ao.quantization.DeQuantStub())

backend = "qnnpack"

quant_model.qconfig = torch.quantization.get_default_qconfig(backend)

torch.backends.quantized.engine = backend

model_static_quantized = torch.quantization.prepare(quant_model, inplace=False)

#### note: you should do calibration here, rather than immediately doing convert
# calibration_code()

model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
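To expand on the calibration note: calibration just means running a few representative inputs through the prepared model, between prepare() and convert(), so the observers can record activation ranges; otherwise the converted modules fall back to the default scale=1.0 / zero_point=0. A rough sketch, where calibration_loader is a placeholder for your own data loader:

model_prepared = torch.quantization.prepare(quant_model, inplace=False)
model_prepared.eval()
with torch.no_grad():
    for images in calibration_loader:    # representative, unlabeled inputs
        model_prepared(images)           # observers record activation ranges
model_static_quantized = torch.quantization.convert(model_prepared, inplace=False)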

Thanks for the code snippet. It works fine until I convert the model to a quantized one. The problem arises when I pass a sample input through the model after converting it. Does the input image need to be modified somehow? An example would be helpful. Thanks.

When I convert I get this:

Sequential(
  (0): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (1): TransPoseH(
    (conv1): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1))
    (bn1): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1))
    (bn2): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
    # ... (skipping the remainder of the layers) ...

)
        (2): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): QuantizedLinear(in_features=96, out_features=96, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          )
          (linear1): QuantizedLinear(in_features=96, out_features=192, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): QuantizedLinear(in_features=192, out_features=96, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (norm1): QuantizedLayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (norm2): QuantizedLayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (3): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): QuantizedLinear(in_features=96, out_features=96, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          )
          (linear1): QuantizedLinear(in_features=96, out_features=192, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): QuantizedLinear(in_features=192, out_features=96, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (norm1): QuantizedLayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (norm2): QuantizedLayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (final_layer): Sequential(
      (0): QuantizedConv2d(96, 16, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
    )
  )
  (2): DeQuantize()
)
/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/observer.py:1109: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point
x = torch.randn(1,3,192, 256)
model_static_quantized(x).shape

# I also tried this and it throws back the same error
# x = torch.quantize_per_tensor(x,
#                               0.1,
#                               10,
#                               dtype=torch.quint8)

  6 #                               dtype=torch.quint8)
  7 

----> 8 model_static_quantized(x).shape

7 frames

/root/.cache/torch/hub/yangsenius_TransPose_main/lib/models/transpose_h.py in forward(self, x)
99 residual = self.downsample(x)
100
----> 101 out += residual
102 out = self.relu(out)
103

NotImplementedError: Could not run ‘aten::add.out’ with arguments from the ‘QuantizedCPU’ backend. This could be because the operator doesn’t exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit Internal Login for possible resolutions. ‘aten::add.out’ is only available for these backends: [CPU, CUDA, Meta, MkldnnCPU, SparseCPU, SparseCUDA, SparseCsrCPU, SparseCsrCUDA, BackendSelect, Python, Named, Conjugate, Negative, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, UNKNOWN_TENSOR_TYPE_ID, Autocast, Batched, VmapMode].

CPU: registered at aten/src/ATen/RegisterCPU.cpp:18433 [kernel]
CUDA: registered at aten/src/ATen/RegisterCUDA.cpp:26496 [kernel]
Meta: registered at aten/src/ATen/RegisterMeta.cpp:12703 [kernel]
MkldnnCPU: registered at aten/src/ATen/RegisterMkldnnCPU.cpp:595 [kernel]
SparseCPU: registered at aten/src/ATen/RegisterSparseCPU.cpp:958 [kernel]
SparseCUDA: registered at aten/src/ATen/RegisterSparseCUDA.cpp:1060 [kernel]
SparseCsrCPU: registered at aten/src/ATen/RegisterSparseCsrCPU.cpp:221 [kernel]
SparseCsrCUDA: registered at aten/src/ATen/RegisterSparseCsrCUDA.cpp:248 [kernel]
BackendSelect: fallthrough registered at …/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at …/aten/src/ATen/core/PythonFallbackKernel.cpp:47 [backend fallback]
Named: fallthrough registered at …/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
Conjugate: registered at …/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at …/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ADInplaceOrView: registered at …/torch/csrc/autograd/generated/ADInplaceOrViewType_0.cpp:2505 [kernel]
AutogradOther: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradCPU: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradCUDA: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradXLA: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradLazy: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradXPU: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradMLC: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradHPU: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradNestedTensor: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradPrivateUse1: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradPrivateUse2: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
AutogradPrivateUse3: registered at …/torch/csrc/autograd/generated/VariableType_4.cpp:8932 [autograd kernel]
Tracer: registered at …/torch/csrc/autograd/generated/TraceType_4.cpp:9308 [kernel]
UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at …/aten/src/ATen/autocast_mode.cpp:466 [backend fallback]
Autocast: fallthrough registered at …/aten/src/ATen/autocast_mode.cpp:305 [backend fallback]
Batched: registered at …/aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at …/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

You have a few options:

1: The eager mode quantization APIs work by taking modules that do normal torch.float (tensor) operations and converting them to versions that do torch.qint8/quint8 (quantized tensor) operations whose values are *close* to the original torch.float ones.

If your model does something that isn't a module, though, then it can't be detected and converted by these APIs, and if that something doesn't support quantized tensors you'll run into the error you are now seeing.

In this case the aten::add.out operation is the issue, i.e. basic tensor addition. Your model has something like Z = X + Y, which leaves the add operation without a module to detect and convert. The standard way to handle this is to initialize something like adder = torch.nn.quantized.FloatFunctional() in __init__ and then call Z = adder.add(X, Y) in the forward (FloatFunctional is just a basic module that was created to 'modularize' a number of basic operations; details: FloatFunctional — PyTorch 1.11.0 documentation). You'd need to do this for each add operation in each module you use, so that there is a unique FloatFunctional for every add in the model. A rough sketch of what that looks like is below.
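A minimal sketch of option 1, using an illustrative residual block rather than the actual TransPose Bottleneck (for the hub model you would have to patch its source):

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        # one FloatFunctional per add operation in the model
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        out = self.conv(x)
        # instead of: out += x  (a bare aten::add that eager mode quantization can't convert)
        out = self.skip_add.add(out, x)
        return self.relu(out)

After convert(), each FloatFunctional is swapped for its quantized counterpart (torch.nn.quantized.QFunctional), so the add runs on quantized tensors.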

2: The easiest solution would be to use dynamic quantization, though it would also be the least performant. The specific issue occurs because the quantization method being used, i.e. static quantization, makes the entire model run with qint8/quint8 activations, so when the add operation sees a qint8/quint8 tensor it doesn't know what to do. Since dynamic quantization doesn't change the activation types, it leaves the activation dtype as float32, which bypasses the concern above.

Note that dynamic quantization wouldn't require the QuantStub and DeQuantStub, since their purpose is to convert to/from qint8/quint8, although leaving them in shouldn't cause any problems either.

A tutorial for dynamic quantization on BERT can be found in our docs. It contains a lot of supporting code specific to BERT, but the main line you'll need is shown at:
https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html#apply-the-dynamic-quantization
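A rough sketch of option 2 (note that eager-mode dynamic quantization mainly covers nn.Linear and RNN-style modules, so for this conv-heavy model the transformer layers would be the main beneficiaries):

import torch

model_dq = torch.quantization.quantize_dynamic(
    model,                 # the fp32 TransPose model from torch.hub
    {torch.nn.Linear},     # module types whose weights get quantized
    dtype=torch.qint8,
)
out = model_dq(torch.randn(1, 3, 256, 192))  # activations stay fp32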

3: Instead of using the eager mode quantization APIs you could try the FX graph mode quantization APIs. Those don't require this kind of modularization, so you wouldn't run into the error above, but they have some other requirements. This would be the second easiest method, although I'm not sure how likely it is to work here. The specific requirement is that your model is FX (symbolically) traceable, which essentially means it has no data-dependent control flow such as if statements on tensor values. Details can be found in the main quantization docs page: Quantization — PyTorch 1.12 documentation. A rough sketch is below.
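A rough sketch of option 3; the prepare_fx signature has shifted between releases, and this follows the 1.12-era API where it takes a qconfig mapping/dict plus example inputs, so check the docs for your version:

import copy
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

float_model = copy.deepcopy(model).eval()
qconfig = get_default_qconfig("qnnpack")
example_inputs = (torch.randn(1, 3, 256, 192),)

prepared = prepare_fx(float_model, {"": qconfig}, example_inputs)
# ... run calibration data through `prepared` here ...
quantized = convert_fx(prepared)

FX traces the out += residual add itself, so no manual FloatFunctional rewrite is needed, provided the whole model traces successfully.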


I was trying to use post-training static quantization on torchvision pre-trained models and came across similar issues, but the code snippet by @HDCharles helped me resolve them. Here is the link to my colab with code snippets for applying post-training static quantization to a custom model and to a torchvision pre-trained model.
Post Training Static Quantization on Pre-trained Torchvision Models


Hey Sairam954, thanks for sharing the colab code. I tried applying the snippet earlier and got the error because certain operations inside the pre-trained model's code are not supported by the PyTorch quantization library. AlexNet in your case probably works because the operations in its architecture are well supported by the library.


As mentioned in the errors, quantization is not supported on the GPU/CUDA backends. Try loading and running the model on the CPU.