Index out of bounds Error with PerChannel Quantization

Hello,

I have encountered this problem while trying to perform per-channel quantization on weights with the ch_axis=1 quantization parameter. It causes an “index out of bounds” error when the dimension of axis 1 of the weight tensor is smaller than the dimension of axis 0 (in the following example, 100 is smaller than 110; note that 100 is axis 1 of the weight matrix). If the axis 0 dimension is smaller (e.g. when changing 110 to 90), the error doesn’t occur. It is not reproducible with ch_axis=0: with ch_axis=0 it doesn’t matter whether one axis is bigger or smaller than the other.
Here is a minimal example that fails:

import torch
import torch.nn as nn
from torch.quantization.observer import PerChannelMinMaxObserver, MinMaxObserver

class QTestNet1(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 110, bias=False)  # weight shape: [110, 100] = [out_features, in_features]
        
    def set_qconfig(self):

        self.linear.qconfig = torch.quantization.QConfig(
            weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, ch_axis=1),  # per-channel along axis 1 (in_features)
            activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine)
        )
    
    def forward(self, x):
        x = self.linear(x)
        return x
        
model = QTestNet1()
model.set_qconfig()
model_prepared = torch.quantization.prepare_qat(model, inplace=False)

input_x = torch.randn(1, 100)
model_prepared(input_x).shape # just checking that forward doesn't fail

model_int8 = torch.quantization.convert(model_prepared.eval()) # error

This is the full error:

/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/quantize.py in convert(module, mapping, inplace, remove_qconfig, is_reference, convert_custom_config_dict)
    519     _convert(
    520         module, mapping, inplace=True, is_reference=is_reference,
--> 521         convert_custom_config_dict=convert_custom_config_dict)
    522     if remove_qconfig:
    523         _remove_qconfig(module)

/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/quantize.py in _convert(module, mapping, inplace, is_reference, convert_custom_config_dict)
    557             _convert(mod, mapping, True,  # inplace
    558                      is_reference, convert_custom_config_dict)
--> 559         reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
    560 
    561     for key, value in reassign.items():

/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/quantize.py in swap_module(mod, mapping, custom_module_class_mapping)
    590                 new_mod = qmod.from_float(mod, weight_qparams)
    591             else:
--> 592                 new_mod = qmod.from_float(mod)
    593             swapped = True
    594 

/usr/local/lib/python3.7/dist-packages/torch/nn/quantized/modules/linear.py in from_float(cls, mod)
    271                       mod.out_features,
    272                       dtype=dtype)
--> 273         qlinear.set_weight_bias(qweight, mod.bias)
    274         qlinear.scale = float(act_scale)
    275         qlinear.zero_point = int(act_zp)

/usr/local/lib/python3.7/dist-packages/torch/nn/quantized/modules/linear.py in set_weight_bias(self, w, b)
    232 
    233     def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
--> 234         self._packed_params.set_weight_bias(w, b)
    235 
    236     @classmethod

/usr/local/lib/python3.7/dist-packages/torch/nn/quantized/modules/linear.py in set_weight_bias(self, weight, bias)
     25     def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
     26         if self.dtype == torch.qint8:
---> 27             self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
     28         elif self.dtype == torch.float16:
     29             self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias)

/usr/local/lib/python3.7/dist-packages/torch/_ops.py in __call__(self, *args, **kwargs)
    141         # We save the function ptr as the `op` attribute on
    142         # OpOverloadPacket to access it here.
--> 143         return self._op(*args, **kwargs or {})
    144 
    145     # TODO: use this to make a __dir__

IndexError: select(): index 100 out of range for tensor of size [100] at dimension 0

I am using torch==1.12.1+cu113 on Google Colab.

I think it probably failed here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L76
I guess we probably need a better error message; prepacking only works for per-channel quantized weights with ch_axis=0, I think.
ch_axis=1 for a weight means the weight is quantized per channel along the input features dimension. I feel it will be hard to implement an efficient kernel for that, since to compute the value of the output activation for each output channel we are accumulating products of input values with weight values that are quantized with different quantization parameters.
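
For your nn.Linear example, a minimal sketch of the currently supported setup (assuming you do want per-channel weights) keeps the channel axis on the output features, i.e. ch_axis=0:

import torch
from torch.quantization.observer import PerChannelMinMaxObserver, MinMaxObserver

# per-channel on axis 0 (out_features) is what the quantized Linear prepack expects
qconfig = torch.quantization.QConfig(
    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, ch_axis=0),
    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine),
)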

Hello,

Thank you for your response. However, many things are still unclear to me.

  1. Why doesn’t it fail when the output feature axis (axis 0) is smaller? Does it still perform the quantization on axis 0, just using indices from axis 1?
  2. Why is ch_axis=1 an option then?
  3. Isn’t multiplication of two matrices implemented as column-by-row (instead of row-by-column), which gives several rank-1 matrices that are summed up together? I think that is the default implementation because it needs fewer register loads (it is easier to load two vectors n times than to reload the columns of the second matrix, from 1 to n, for each row of the first matrix). In this implementation, it seems like we just have to sum up n matrices that are quantized with different parameters. Can’t summing them be implemented efficiently?

When the observer runs, it is assumed that you are observing the output channels. For the linear layer, the output channels are always at axis 0. That basically means that observing axis 1 will not work from the observation point of view.

To answer your questions:

  1. Why doesn’t it fail when the output feature axis (axis 0) is smaller? Does it still perform the quantization on axis 0, just using indices from axis 1?

This is because the main assumption is that you are observing the output channels at axis 0. At this line the size of the 0th axis is read, while at this line (and several others) that number is used to index into the per-channel quantization parameters – hence the error.
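
A minimal sketch of that mismatch, calling the prepack op directly on a [110, 100] weight quantized along axis 1 (on torch 1.12 this should raise the same select() IndexError as in your traceback):

import torch

w = torch.randn(110, 100)                      # [out_features, in_features]
scales = torch.ones(100, dtype=torch.double)   # one qparam per channel along axis 1
zero_points = torch.zeros(100, dtype=torch.int64)
qw = torch.quantize_per_channel(w, scales, zero_points, axis=1, dtype=torch.qint8)

# linear_prepack assumes the per-channel qparams live on axis 0 (out_features),
# so it looks up qparams for 110 output rows in a 100-element tensor
packed = torch.ops.quantized.linear_prepack(qw, None)  # IndexError: index 100 out of range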

  2. Why is ch_axis=1 an option then?

This is an option for layers that have the output channels at a different location, such as ConvTranspose.
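
For example, ConvTranspose2d stores its weight as [in_channels, out_channels, kH, kW], so the output channels sit on axis 1, unlike Conv2d:

import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)
deconv = nn.ConvTranspose2d(16, 32, kernel_size=3)

print(conv.weight.shape)    # torch.Size([32, 16, 3, 3]) -> out_channels on axis 0
print(deconv.weight.shape)  # torch.Size([16, 32, 3, 3]) -> out_channels on axis 1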

  3. Isn’t multiplication of two matrices implemented as column-by-row (instead of row-by-column), which gives several rank-1 matrices that are summed up together?

In theory that is correct; however, the exact implementation depends on the backend that is being used. Is your question about how the matmul is implemented in theory, or about how it is implemented in PyTorch specifically?
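
For reference, a small sketch of the decomposition you describe: the product as a sum of rank-1 (column times row) outer products. Each rank-1 term corresponds to one input channel, which is why with ch_axis=1 every term in the sum would carry a different weight scale:

import torch

A = torch.randn(4, 3)
B = torch.randn(3, 5)

# A @ B as a sum of rank-1 outer products, one per inner index k
outer_sum = sum(torch.outer(A[:, k], B[k, :]) for k in range(A.shape[1]))
torch.testing.assert_close(A @ B, outer_sum)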

So how can I quantize per_channel? When I use the default params for PerChannelMinMaxObserver(), it just provides 1 scale and 1 zero-point for the whole tensor.

(m): ModuleList(
        (0): Bottleneck(
          (cv1): Q_Conv(
            (quant): Quantize(scale=tensor([0.0938]), zero_point=tensor([0]), dtype=torch.quint8)
            (conv): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.08255661278963089, zero_point=0, padding=(1, 1))
            (bn): Identity()
            (act): Identity()
            (dequant): DeQuantize()
          )
          (cv2): Q_Conv(
            (quant): Quantize(scale=tensor([0.0710]), zero_point=tensor([0]), dtype=torch.quint8)
            (conv): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.10519054532051086, zero_point=0, padding=(1, 1))
            (bn): Identity()
            (act): Identity()
            (dequant): DeQuantize()
          )
        )
      )

Besides, although the default qscheme per_channel_affine is set, the zero-point when quantizing the model is still 0.
Are there any solutions for this problem?

in general, no.

The issue is that while the quantization tools look like they are fully composable, they are limited by the available kernels and backends.

firstly, we don’t have any kernels where activations are quantized per channel. So for that Quantize op which quantizes activations, it doesn’t matter how you set up the qconfig – you’re not going to get a coherent resulting op with activations that are quantized per channel. Same with QuantizedConvReLU2d: the scale/zero_point you see printed are the output activation quantization scale and zero point, which are never per-channel. Now what you CAN quantize per-channel is the weights, but accessing QuantizedConvReLU2d.weight won’t work either because there is no weight attribute; the weights are packed into a special form. If you call the .weight() method you can unpack the weight and look at it.
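
A minimal sketch of that inspection, where qconv is a stand-in for one of the quantized conv modules from your printout (e.g. cv1.conv):

# qconv: hypothetical handle to one of the QuantizedConvReLU2d modules above
qw = qconv.weight()                    # unpack the packed weight into a quantized tensor
print(qw.qscheme())                    # e.g. torch.per_channel_affine
print(qw.q_per_channel_scales())       # one scale per output channel
print(qw.q_per_channel_zero_points())  # one zero-point per output channel
print(qw.q_per_channel_axis())         # the channel axis (0 for conv weights)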

so to be clear, the per-channel stuff is meant for the weights, which don’t show up in the module printout the same way the activation quantization qparams do – which seems to be part of the confusion.


Oh, I get it. Thank you.